diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..f6cb8ad931f5442a5e2276ed66ba7b5a733b82c6 --- /dev/null +++ b/.clang-format @@ -0,0 +1 @@ +BasedOnStyle: Google diff --git a/.compatibility b/.compatibility new file mode 100644 index 0000000000000000000000000000000000000000..c8ac4083d2a2b2985a02b6ec281a0e33ab4f23b2 --- /dev/null +++ b/.compatibility @@ -0,0 +1,3 @@ +1.12.0-11.3.0 +1.11.0-11.3.0 +1.10.1-11.3.0 diff --git a/.cuda_ext.json b/.cuda_ext.json new file mode 100644 index 0000000000000000000000000000000000000000..eba19cf05e313e9256c80c1c82a70faa40ef7b32 --- /dev/null +++ b/.cuda_ext.json @@ -0,0 +1,16 @@ +{ + "build": [ + { + "torch_command": "pip install torch==1.12.1+cu102 torchvision==0.13.1+cu102 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu102", + "cuda_image": "hpcaitech/cuda-conda:10.2" + }, + { + "torch_command": "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113", + "cuda_image": "hpcaitech/cuda-conda:11.3" + }, + { + "torch_command": "pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116", + "cuda_image": "hpcaitech/cuda-conda:11.6" + } + ] +} diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000000000000000000000000000000000..229856aa4366c092e94eb573dbbd67756974a3ae --- /dev/null +++ b/.flake8 @@ -0,0 +1,22 @@ +[flake8] +ignore = + ;W503 line break before binary operator + W503, + ;E203 whitespace before ':' + E203, + +; exclude file +exclude = + .tox, + .git, + __pycache__, + build, + dist, + *.pyc, + *.egg-info, + .cache, + .eggs + +max-line-length = 120 + +per-file-ignores = __init__.py:F401 diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 0000000000000000000000000000000000000000..c09c1030843ba057b64a093b31ce6f70df323f68 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,38 @@ +name: 🐛 Bug Report +description: Create a report to help us reproduce and fix the bug +title: "[BUG]: " +labels: [bug] + +body: +- type: markdown + attributes: + value: > + #### Not suitable for your needs? [Open a blank issue](https://github.com/hpcaitech/ColossalAI/issues/new). +- type: textarea + attributes: + label: 🐛 Describe the bug + description: | + **Describe the bug** + A clear and concise description of what the bug is. + **To Reproduce** + Steps or code snippet to reproduce the behavior. + **Expected behavior** + A clear and concise description of what you expected to happen. + **Screenshots** + If applicable, add screenshots to help explain your problem. + **Optional: Affiliation** + Institution/email information helps better analyze and evaluate users to improve the project. Welcome to establish in-depth cooperation. + placeholder: | + A clear and concise description of what the bug is. + validations: + required: true +- type: textarea + attributes: + label: Environment + description: | + Please provide the environment information, eg. CUDA/cuDNN/NCCL/Python/PyTorch version. + +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..b310fcfefc154e3c8b3e941a363181743c95b5ab --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,11 @@ +blank_issues_enabled: true +contact_links: + - name: ❓ Simple question - Slack Chat + url: https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w + about: This issue tracker is not for technical support. Please use our Slack chat, and ask the community for help. + - name: ❓ Simple question - WeChat + url: https://github.com/hpcaitech/ColossalAI/blob/main/docs/images/WeChat.png + about: This issue tracker is not for technical support. Please use WeChat, and ask the community for help. + - name: 😊 Advanced question - GitHub Discussions + url: https://github.com/hpcaitech/ColossalAI/discussions + about: Use GitHub Discussions for advanced and unanswered technical questions, requiring a maintainer's answer. diff --git a/.github/ISSUE_TEMPLATE/documentation.yml b/.github/ISSUE_TEMPLATE/documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..511997e2ee813d2bf84bf6fd7424ad580fc01bc3 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/documentation.yml @@ -0,0 +1,29 @@ +name: 📚 Documentation +description: Report an issue related to https://www.colossalai.org/ +title: "[DOC]: " +labels: [documentation] + +body: +- type: markdown + attributes: + value: > + #### Not suitable for your needs? [Open a blank issue](https://github.com/hpcaitech/ColossalAI/issues/new). +- type: textarea + attributes: + label: 📚 The doc issue + description: | + **Description** What content in [Documentation](https://www.colossalai.org/) is an issue? + **Location** Where is the issue location? + **Expectation** What is your expected content about it? + **Screenshots** If applicable, add screenshots to help explain your problem. + **Suggestions** Tell us how we could improve the documentation. + **Optional: Affiliation** Institution/email information helps better analyze and evaluate users to improve the project. Welcome to establish in-depth cooperation. + placeholder: | + A clear and concise description of the issue. + validations: + required: true + +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000000000000000000000000000000000000..f12c41b52e6fe8f5b128959da3c807b6ccc7f4e9 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,35 @@ +name: 🚀 Feature request +description: Suggest an idea for this project +title: "[FEATURE]: " +labels: [enhancement] + +body: +- type: markdown + attributes: + value: > + #### Not suitable for your needs? [Open a blank issue](https://github.com/hpcaitech/ColossalAI/issues/new). +- type: textarea + attributes: + label: Describe the feature + description: | + **Is your feature request related to a problem? Please describe.** + A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + **Describe the solution you'd like** + A clear and concise description of what you want to happen. + **Describe alternatives you've considered** + A clear and concise description of any alternative solutions or features you've considered. + **Screenshots** + If applicable, add screenshots to help explain your problem. + **Suggest a potential alternative/fix** + Tell us how we could improve this project. + **Optional: Affiliation** + Institution/email information helps better analyze and evaluate users to improve the project. Welcome to establish in-depth cooperation. + placeholder: | + A clear and concise description of your idea. + validations: + required: true + +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! diff --git a/.github/ISSUE_TEMPLATE/proposal.yml b/.github/ISSUE_TEMPLATE/proposal.yml new file mode 100644 index 0000000000000000000000000000000000000000..614ef77751ad88f6dd55856b89fe91ffd1d294b8 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/proposal.yml @@ -0,0 +1,47 @@ +name: 💥 Proposal +description: Propose a non-trivial change to Colossal-AI +title: "[PROPOSAL]: " +labels: [enhancement] + +body: +- type: markdown + attributes: + value: | + Common reasons for proposals include: + + - Altering the infrastructure; + - Bumping a critical dependency's major version; + - A significant improvement in user-friendliness; + - Significant refactor; + - Optional: Affiliation/email information helps better analyze and evaluate users to improve the project. Welcome to establish in-depth cooperation. + - ... + + Please note this is not for feature request or bug template; such action could make us identify the issue wrongly and close it without doing anything. + + We give you maximum freedom to write an elaborated proposal illustrating why you think the change is beneficial for us, and what steps we should take to turn this into reality. + + +- type: textarea + attributes: + label: Proposal + description: A clear and concise description of what the proposal is. + validations: + required: true + +- type: checkboxes + attributes: + label: Self-service + description: | + If you feel like you could contribute to this issue, please check the box below. This would tell us and other people looking for contributions that someone's working on it. + If you do check this box, please send a pull request within 7 days after a maintainer's approval so we can still delegate this to someone else. + + Proposals usually involve significant code changes, so please reach consensus with the maintainers before rushing to implement it, and make sure you follow the [Contributing Guidelines](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). + This ensures that you don't waste your time and we don't waste ours reading the large diffs. + options: + - label: I'd be willing to do some initial work on this proposal myself. + + +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000000000000000000000000000000000000..0dc1262b008c70c4e8a76f5edfd95dc4624e9067 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,36 @@ +## 📌 Checklist before creating the PR + +- [ ] I have created an issue for this PR for traceability +- [ ] The title follows the standard format: `[doc/gemini/tensor/...]: A concise description` +- [ ] I have added relevant tags if possible for us to better distinguish different PRs + + +## 🚨 Issue number + +> Link this PR to your issue with words like fixed to automatically close the linked issue upon merge +> +> e.g. `fixed #1234`, `closed #1234`, `resolved #1234` + + + +## 📝 What does this PR do? + +> Summarize your work here. +> if you have any plots/diagrams/screenshots/tables, please attach them here. + + + +## 💥 Checklist before requesting a review + +- [ ] I have linked my PR to an issue ([instruction](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue)) +- [ ] My issue clearly describes the problem/feature/proposal, with diagrams/charts/table/code if possible +- [ ] I have performed a self-review of my code +- [ ] I have added thorough tests. +- [ ] I have added docstrings for all the functions/methods I implemented + +## ⭐️ Do you enjoy contributing to Colossal-AI? + +- [ ] 🌝 Yes, I do. +- [ ] 🌚 No, I don't. + +Tell us more if you don't enjoy contributing to Colossal-AI. diff --git a/.github/workflows/README.md b/.github/workflows/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a46d8b1c24d05804320acf09eff25fd2fcab3fa9 --- /dev/null +++ b/.github/workflows/README.md @@ -0,0 +1,157 @@ +# CI/CD + +## Table of Contents + +- [CI/CD](#cicd) + - [Table of Contents](#table-of-contents) + - [Overview](#overview) + - [Workflows](#workflows) + - [Code Style Check](#code-style-check) + - [Unit Test](#unit-test) + - [Example Test](#example-test) + - [Example Test on Dispatch](#example-test-on-dispatch) + - [Compatibility Test](#compatibility-test) + - [Compatibility Test on Dispatch](#compatibility-test-on-dispatch) + - [Release](#release) + - [User Friendliness](#user-friendliness) + - [Commmunity](#commmunity) + - [Configuration](#configuration) + - [Progress Log](#progress-log) + +## Overview + +Automation makes our development more efficient as the machine automatically run the pre-defined tasks for the contributors. +This saves a lot of manual work and allow the developer to fully focus on the features and bug fixes. +In Colossal-AI, we use [GitHub Actions](https://github.com/features/actions) to automate a wide range of workflows to ensure the robustness of the software. +In the section below, we will dive into the details of different workflows available. + +## Workflows + +Refer to this [documentation](https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow) on how to manually trigger a workflow. +I will provide the details of each workflow below. + +**A PR which changes the `version.txt` is considered as a release PR in the following context.** + + +### Code Style Check + +| Workflow Name | File name | Description | +| ------------- | ----------------- | -------------------------------------------------------------------------------------------------------------- | +| `post-commit` | `post_commit.yml` | This workflow runs pre-commit checks for changed files to achieve code style consistency after a PR is merged. | + +### Unit Test + +| Workflow Name | File name | Description | +| ---------------------- | -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------- | +| `Build on PR` | `build_on_pr.yml` | This workflow is triggered when the label `Run build and Test` is assigned to a PR. It will run all the unit tests in the repository with 4 GPUs. | +| `Build on Schedule` | `build_on_schedule.yml` | This workflow will run the unit tests everyday with 8 GPUs. The result is sent to Lark. | +| `Report test coverage` | `report_test_coverage.yml` | This PR will put up a comment to report the test coverage results when `Build` is done. | + +### Example Test + +| Workflow Name | File name | Description | +| -------------------------- | ------------------------------- | ------------------------------------------------------------------------------ | +| `Test example on PR` | `example_check_on_pr.yml` | The example will be automatically tested if its files are changed in the PR | +| `Test example on Schedule` | `example_check_on_schedule.yml` | This workflow will test all examples every Sunday. The result is sent to Lark. | +| `Example Test on Dispatch` | `example_check_on_dispatch.yml` | Manually test a specified example. | + +#### Example Test on Dispatch + +This workflow is triggered by manually dispatching the workflow. It has the following input parameters: +- `example_directory`: the example directory to test. Multiple directories are supported and must be separated by comma. For example, language/gpt, images/vit. Simply input language or simply gpt does not work. + +### Compatibility Test + +| Workflow Name | File name | Description | +| -------------------------------- | ------------------------------------ | -------------------------------------------------------------------------------------------------------------------- | +| `Compatibility Test on PR` | `compatibility_test_on_pr.yml` | Check Colossal-AI's compatibility when `version.txt` is changed in a PR. | +| `Compatibility Test on Schedule` | `compatibility_test_on_schedule.yml` | This workflow will check the compatibility of Colossal-AI against PyTorch specified in `.compatibility` every Sunday. | +| `Compatibility Test on Dispatch` | `compatibility_test_on_dispatch.yml` | Test PyTorch Compatibility manually. | + + +#### Compatibility Test on Dispatch +This workflow is triggered by manually dispatching the workflow. It has the following input parameters: +- `torch version`:torch version to test against, multiple versions are supported but must be separated by comma. The default is value is all, which will test all available torch versions listed in this [repository](https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels). +- `cuda version`: cuda versions to test against, multiple versions are supported but must be separated by comma. The CUDA versions must be present in our [DockerHub repository](https://hub.docker.com/r/hpcaitech/cuda-conda). + +> It only test the compatibility of the main branch + + +### Release + +| Workflow Name | File name | Description | +| ----------------------------------------------- | ------------------------------------------- | ------------------------------------------------------------------------------------------------------------- | +| `Draft GitHub Release Post` | `draft_github_release_post_after_merge.yml` | Compose a GitHub release post draft based on the commit history when a release PR is merged. | +| `Publish to PyPI` | `release_pypi_after_merge.yml` | Build and release the wheel to PyPI when a release PR is merged. The result is sent to Lark. | +| `Publish Nightly Version to PyPI` | `release_nightly_on_schedule.yml` | Build and release the nightly wheel to PyPI as `colossalai-nightly` every Sunday. The result is sent to Lark. | +| `Publish Docker Image to DockerHub after Merge` | `release_docker_after_merge.yml` | Build and release the Docker image to DockerHub when a release PR is merged. The result is sent to Lark. | +| `Check CUDA Extension Build Before Merge` | `cuda_ext_check_before_merge.yml` | Build CUDA extensions with different CUDA versions when a release PR is created. | +| `Publish to Test-PyPI Before Merge` | `release_test_pypi_before_merge.yml` | Release to test-pypi to simulate user installation when a release PR is created. | + + +### User Friendliness + +| Workflow Name | File name | Description | +| ----------------------- | ----------------------- | -------------------------------------------------------------------------------------------------------------------------------------- | +| `issue-translate` | `translate_comment.yml` | This workflow is triggered when a new issue comment is created. The comment will be translated into English if not written in English. | +| `Synchronize submodule` | `submodule.yml` | This workflow will check if any git submodule is updated. If so, it will create a PR to update the submodule pointers. | +| `Close inactive issues` | `close_inactive.yml` | This workflow will close issues which are stale for 14 days. | + +### Commmunity + +| Workflow Name | File name | Description | +| -------------------------------------------- | -------------------------------- | -------------------------------------------------------------------------------- | +| `Generate Community Report and Send to Lark` | `report_leaderboard_to_lark.yml` | Collect contribution and user engagement stats and share with Lark every Friday. | + +## Configuration + +This section lists the files used to configure the workflow. + +1. `.compatibility` + +This `.compatibility` file is to tell GitHub Actions which PyTorch and CUDA versions to test against. Each line in the file is in the format `${torch-version}-${cuda-version}`, which is a tag for Docker image. Thus, this tag must be present in the [docker registry](https://hub.docker.com/r/pytorch/conda-cuda) so as to perform the test. + +2. `.cuda_ext.json` + +This file controls which CUDA versions will be checked against CUDA extension built. You can add a new entry according to the json schema below to check the AOT build of PyTorch extensions before release. + +```json +{ + "build": [ + { + "torch_command": "", + "cuda_image": "" + }, + ] +} +``` + +## Progress Log + +- [x] Code style check + - [x] post-commit check +- [x] unit testing + - [x] test on PR + - [x] report test coverage + - [x] regular test +- [x] release + - [x] pypi release + - [x] test-pypi simulation + - [x] nightly build + - [x] docker build + - [x] draft release post +- [x] example check + - [x] check on PR + - [x] regular check + - [x] manual dispatch +- [x] compatibility check + - [x] check on PR + - [x] manual dispatch + - [x] auto test when release +- [x] community + - [x] contribution report + - [x] user engagement report +- [x] helpers + - [x] comment translation + - [x] submodule update + - [x] close inactive issue diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml new file mode 100644 index 0000000000000000000000000000000000000000..e6febeeb4d87256efc3219b18b9b53875e54aa21 --- /dev/null +++ b/.github/workflows/build_on_pr.yml @@ -0,0 +1,167 @@ +name: Build on PR + +on: + pull_request: + types: [synchronize, labeled] + +jobs: + detect: + name: Detect file change + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && + contains( github.event.pull_request.labels.*.name, 'Run Build and Test') + outputs: + changedExtenisonFiles: ${{ steps.find-extension-change.outputs.all_changed_files }} + anyExtensionFileChanged: ${{ steps.find-extension-change.outputs.any_changed }} + changedLibraryFiles: ${{ steps.find-lib-change.outputs.all_changed_files }} + anyLibraryFileChanged: ${{ steps.find-lib-change.outputs.any_changed }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + ref: ${{ github.event.pull_request.head.sha }} + + - name: Locate base commit + id: locate-base-sha + run: | + curBranch=$(git rev-parse --abbrev-ref HEAD) + commonCommit=$(git merge-base origin/main $curBranch) + echo $commonCommit + echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT + + - name: Find the changed extension-related files + id: find-extension-change + uses: tj-actions/changed-files@v35 + with: + base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }} + files: | + op_builder/** + colossalai/kernel/** + setup.py + + - name: Find the changed library-related files + id: find-lib-change + uses: tj-actions/changed-files@v35 + with: + base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }} + files: | + **/*.py + **/*.h + **/*.cpp + **/*.cu + **/*.txt + + - name: List changed files + run: | + for file in ${{ steps.find-extension-change.outputs.all_changed_files }}; do + echo "$file was changed" + done + for file in ${{ steps.find-lib-change.outputs.all_changed_files }}; do + echo "$file was changed" + done + + build: + name: Build and Test Colossal-AI + needs: detect + runs-on: [self-hosted, gpu] + container: + image: hpcaitech/pytorch-cuda:1.11.0-11.3.0 + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + timeout-minutes: 40 + defaults: + run: + shell: bash + steps: + - name: Checkout TensorNVMe + uses: actions/checkout@v2 + with: + repository: hpcaitech/TensorNVMe + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + path: TensorNVMe + + - name: Restore TensorNVMe Cache + run: | + [ ! -z "$(ls -A /github/home/tensornvme_cache/)" ] && cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe + + - name: Install TensorNVMe + run: | + cd TensorNVMe + conda install cmake + pip install -r requirements.txt + pip install -v . + + - name: Store TensorNVMe Cache + run: | + cd TensorNVMe + cp -p -r ./build /github/home/tensornvme_cache/ + + - name: Checkout Colossal-AI + uses: actions/checkout@v2 + with: + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + + - name: Restore Colossal-AI Cache + if: needs.detect.outputs.anyExtensionFileChanged != 'true' + run: | + # -p flag is required to preserve the file timestamp to avoid ninja rebuild + [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ + + - name: Install Colossal-AI + if: needs.detect.outputs.anyLibraryFileChanged == 'true' + run: | + CUDA_EXT=1 pip install -v -e . + pip install -r requirements/requirements-test.txt + + - name: Store Colossal-AI Cache + run: | + # -p flag is required to preserve the file timestamp to avoid ninja rebuild + cp -p -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ + + - name: Execute Unit Testing + if: needs.detect.outputs.anyLibraryFileChanged == 'true' + run: | + CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/ + env: + DATA: /data/scratch/cifar-10 + NCCL_SHM_DISABLE: 1 + LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + + - name: Collate artifact + env: + PR_NUMBER: ${{ github.event.number }} + changedLibraryFiles: ${{ needs.detect.outputs.changedLibraryFiles }} + anyLibraryFileChanged: ${{ needs.detect.outputs.anyLibraryFileChanged }} + changedExtenisonFiles: ${{ needs.detect.outputs.changedExtenisonFiles }} + run: | + mkdir report + echo $PR_NUMBER > ./report/pr_number + + # generate coverage.xml if any + if [ "$anyLibraryFileChanged" == "true" ]; then + allFiles="" + for file in $changedLibraryFiles; do + if [ "$allFiles" == "" ]; then + allFiles=$file + else + allFiles=$allFiles,$file + fi + done + + coverage report --data-file .coverage --include $allFiles > ./coverage.txt + + covPercentage=$(tail -n 1 coverage.txt | grep -o '[1-9]*%$') + covNum=${covPercentage::-1} + mv coverage.txt ./report + echo $covNum > ./report/cov_number + else + echo "No coverage report is generated" + fi + + - name: Upload test coverage artifact + uses: actions/upload-artifact@v3 + with: + name: report + path: report/ diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml new file mode 100644 index 0000000000000000000000000000000000000000..6afdf581e6ca42e078118ef238e3d4199de9aa8f --- /dev/null +++ b/.github/workflows/build_on_schedule.yml @@ -0,0 +1,80 @@ +name: Build on Schedule + +on: + schedule: + # run at 00:00 of every Sunday + - cron: '0 0 * * *' + workflow_dispatch: + +jobs: + build: + name: Build and Test Colossal-AI + if: github.repository == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, 8-gpu] + container: + image: hpcaitech/pytorch-cuda:1.11.0-11.3.0 + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + timeout-minutes: 40 + steps: + - name: Check GPU Availability # ensure all GPUs have enough memory + id: check-avai + run: | + avai=true + for i in $(seq 0 7); + do + gpu_used=$(nvidia-smi -i $i --query-gpu=memory.used --format=csv,noheader,nounits) + [ "$gpu_used" -gt "10000" ] && avai=false + done + + echo "GPU is available: $avai" + echo "avai=$avai" >> $GITHUB_OUTPUT + + - uses: actions/checkout@v2 + if: steps.check-avai.outputs.avai == 'true' + with: + repository: hpcaitech/TensorNVMe + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + path: TensorNVMe + + - name: Install tensornvme + if: steps.check-avai.outputs.avai == 'true' + run: | + cd TensorNVMe + conda install cmake + pip install -r requirements.txt + pip install -v . + + - uses: actions/checkout@v2 + if: steps.check-avai.outputs.avai == 'true' + with: + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + + - name: Install Colossal-AI + if: steps.check-avai.outputs.avai == 'true' + run: | + [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ + CUDA_EXT=1 pip install -v -e . + cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ + pip install -r requirements/requirements-test.txt + + - name: Unit Testing + if: steps.check-avai.outputs.avai == 'true' + run: | + PYTHONPATH=$PWD pytest tests + env: + DATA: /data/scratch/cifar-10 + LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + + - name: Notify Lark + id: message-preparation + if: ${{ failure() }} + run: | + url=$SERVER_URL/$REPO/actions/runs/$RUN_ID + msg="Scheduled Build and Test failed on 8 GPUs, please visit $url for details" + echo $msg + python .github/workflows/scripts/send_message_to_lark.py -m "$msg" -u $WEBHOOK_URL + env: + SERVER_URL: ${{github.server_url }} + REPO: ${{ github.repository }} + RUN_ID: ${{ github.run_id }} + WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} diff --git a/.github/workflows/close_inactive.yml b/.github/workflows/close_inactive.yml new file mode 100644 index 0000000000000000000000000000000000000000..e7dec44309303f4602cab8b3205f25cabb07cf30 --- /dev/null +++ b/.github/workflows/close_inactive.yml @@ -0,0 +1,26 @@ +name: Close inactive issues + +on: + schedule: + - cron: "0 0 * * *" + +jobs: + close-issues: + if: github.event.pull_request.draft == false && github.base_ref == 'main' && github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - uses: actions/stale@v3 + with: + days-before-issue-stale: 14 + days-before-issue-close: -1 + stale-issue-label: "stale" + stale-issue-message: "This issue is stale because it has been open for 14 days with no activity." +# close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." + days-before-pr-stale: 14 + days-before-pr-close: -1 + stale-pr-message: "This PR is stale because it has been open for 14 days with no activity." +# close-pr-message: "This PR was closed because it has been inactive for 14 days since being marked as stale." + repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml new file mode 100644 index 0000000000000000000000000000000000000000..717cf729b3f3fb13d29daa075129c6c653808eb4 --- /dev/null +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -0,0 +1,84 @@ +name: Compatibility Test on Dispatch + +on: + workflow_dispatch: + inputs: + torch_version: + type: string + description: torch version, separated by comma + required: true + cuda_version: + type: string + description: cuda version, separated by comma + required: true + +jobs: + matrix_preparation: + name: Prepare Container List + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - id: set-matrix + env: + TORCH_VERSIONS: ${{ inputs.torch_version }} + CUDA_VERSIONS: ${{ inputs.cuda_version }} + run: | + IFS=',' + DOCKER_IMAGE=() + + for tv in $TORCH_VERSIONS + do + for cv in $CUDA_VERSIONS + do + DOCKER_IMAGE+=("\"hpcaitech/pytorch-cuda:${tv}-${cv}\"") + done + done + + container=$( IFS=',' ; echo "${DOCKER_IMAGE[*]}" ) + container="[${container}]" + echo "$container" + echo "::set-output name=matrix::{\"container\":$(echo "$container")}" + + build: + name: Test for PyTorch Compatibility + needs: matrix_preparation + if: github.repository == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, gpu] + strategy: + fail-fast: false + matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} + container: + image: ${{ matrix.container }} + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + timeout-minutes: 120 + steps: + - name: Install dependencies + run: | + pip install -U pip setuptools wheel --user + - uses: actions/checkout@v2 + with: + repository: hpcaitech/TensorNVMe + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + path: TensorNVMe + - name: Install tensornvme + run: | + cd TensorNVMe + conda install cmake + pip install -r requirements.txt + pip install -v . + - uses: actions/checkout@v2 + with: + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + - name: Install Colossal-AI + run: | + pip install -r requirements/requirements.txt + pip install -v --no-cache-dir . + pip install -r requirements/requirements-test.txt + - name: Unit Testing + run: | + PYTHONPATH=$PWD pytest tests + env: + DATA: /data/scratch/cifar-10 + NCCL_SHM_DISABLE: 1 + LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml new file mode 100644 index 0000000000000000000000000000000000000000..2fca67b820a1d7cf411f1ee9b854f93a150c639d --- /dev/null +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -0,0 +1,71 @@ +name: Compatibility Test on PR + +on: + pull_request: + paths: + - 'version.txt' + - '.compatibility' + +jobs: + matrix_preparation: + name: Prepare Container List + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v3 + - id: set-matrix + run: | + IFS=',' + DOCKER_IMAGE=() + + while read tag; do + DOCKER_IMAGE+=("\"hpcaitech/pytorch-cuda:${tag}\"") + done <.compatibility + + container=$( IFS=',' ; echo "${DOCKER_IMAGE[*]}" ) + container="[${container}]" + echo "$container" + echo "::set-output name=matrix::{\"container\":$(echo "$container")}" + + build: + name: Test for PyTorch Compatibility + needs: matrix_preparation + if: github.repository == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, gpu] + strategy: + fail-fast: false + matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} + container: + image: ${{ matrix.container }} + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + timeout-minutes: 120 + steps: + - name: Install dependencies + run: | + pip install -U pip setuptools wheel --user + - uses: actions/checkout@v2 + with: + repository: hpcaitech/TensorNVMe + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + path: TensorNVMe + - name: Install tensornvme + run: | + cd TensorNVMe + conda install cmake + pip install -r requirements.txt + pip install -v . + - uses: actions/checkout@v2 + with: + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + - name: Install Colossal-AI + run: | + pip install -v --no-cache-dir . + pip install -r requirements/requirements-test.txt + - name: Unit Testing + run: | + PYTHONPATH=$PWD pytest tests + env: + DATA: /data/scratch/cifar-10 + NCCL_SHM_DISABLE: 1 + LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml new file mode 100644 index 0000000000000000000000000000000000000000..9802795fad246864db5c7a86d9f4273f2ce99ff9 --- /dev/null +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -0,0 +1,90 @@ +name: Compatibility Test on Schedule + +on: + # run at 03:00 of every Sunday(singapore time) so here is UTC time Saturday 16:00 + schedule: + - cron: '0 19 * * 6' + workflow_dispatch: + +jobs: + matrix_preparation: + name: Prepare Container List + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v3 + - id: set-matrix + run: | + IFS=',' + DOCKER_IMAGE=() + + while read tag; do + DOCKER_IMAGE+=("\"hpcaitech/pytorch-cuda:${tag}\"") + done <.compatibility + + container=$( IFS=',' ; echo "${DOCKER_IMAGE[*]}" ) + container="[${container}]" + echo "$container" + echo "::set-output name=matrix::{\"container\":$(echo "$container")}" + + build: + name: Test for PyTorch Compatibility + needs: matrix_preparation + if: github.repository == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, gpu] + strategy: + fail-fast: false + matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} + container: + image: ${{ matrix.container }} + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + timeout-minutes: 120 + steps: + - name: Install dependencies + run: | + pip install -U pip setuptools wheel --user + + - uses: actions/checkout@v2 + with: + repository: hpcaitech/TensorNVMe + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + path: TensorNVMe + + - name: Install tensornvme + run: | + cd TensorNVMe + conda install cmake + pip install -r requirements.txt + pip install -v . + - uses: actions/checkout@v2 + with: + ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + + - name: Install Colossal-AI + run: | + pip install -v --no-cache-dir . + pip install -r requirements/requirements-test.txt + + - name: Unit Testing + run: | + PYTHONPATH=$PWD pytest tests + env: + DATA: /data/scratch/cifar-10 + NCCL_SHM_DISABLE: 1 + LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + + - name: Notify Lark + id: message-preparation + if: ${{ failure() }} + run: | + url=$SERVER_URL/$REPO/actions/runs/$RUN_ID + msg="Compatibility test failed with $container, please visit $url for details" + echo $msg + python .github/workflows/scripts/send_message_to_lark.py -m "$msg" -u $WEBHOOK_URL + env: + SERVER_URL: ${{github.server_url }} + REPO: ${{ github.repository }} + RUN_ID: ${{ github.run_id }} + WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} + container: ${{ matrix.container }} diff --git a/.github/workflows/cuda_ext_check_before_merge.yml b/.github/workflows/cuda_ext_check_before_merge.yml new file mode 100644 index 0000000000000000000000000000000000000000..eba5bb98ec07998314708d1b8a159de0d574fcf2 --- /dev/null +++ b/.github/workflows/cuda_ext_check_before_merge.yml @@ -0,0 +1,42 @@ +name: Check CUDA Extension Build Before Merge + +on: + workflow_dispatch: + pull_request: + paths: + - 'version.txt' + +jobs: + matrix_preparation: + name: Prepare Container List + if: github.repository == 'hpcaitech/ColossalAI' + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v3 + + - id: set-matrix + run: | + cuda_ext=$(cat .cuda_ext.json | tr '\n' ' ') + echo "matrix=${cuda_ext}" >> $GITHUB_OUTPUT + + build: + name: Release bdist wheels + needs: matrix_preparation + runs-on: [self-hosted, gpu] + strategy: + fail-fast: false + matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} + container: + image: ${{ matrix.build.cuda_image }} + options: --gpus all --rm + steps: + - uses: actions/checkout@v2 + + - name: Install PyTorch + run: eval ${{ matrix.build.torch_command }} + + - name: Build + run: | + CUDA_EXT=1 pip install -v . diff --git a/.github/workflows/doc_build_after_merge.yml b/.github/workflows/doc_build_after_merge.yml new file mode 100644 index 0000000000000000000000000000000000000000..ede04b336620870e3739458e68bfd31afb73ce8c --- /dev/null +++ b/.github/workflows/doc_build_after_merge.yml @@ -0,0 +1,28 @@ +name: Build Documentation After Merge + +on: + workflow_dispatch: + pull_request: + paths: + - 'version.txt' + - 'docs/**' + types: + - closed + +jobs: + build-doc: + name: Trigger Documentation Build Workflow + if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI' + runs-on: ubuntu-latest + steps: + - name: trigger workflow in ColossalAI-Documentation + run: | + curl \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${GH_TOKEN}"\ + -H "X-GitHub-Api-Version: 2022-11-28" \ + https://api.github.com/repos/hpcaitech/ColossalAI-Documentation/actions/workflows/deploy.yml/dispatches \ + -d '{"ref":"main"}' + env: + GH_TOKEN: ${{secrets.DOC_REPO_TOKEN}} diff --git a/.github/workflows/doc_check_on_pr.yml b/.github/workflows/doc_check_on_pr.yml new file mode 100644 index 0000000000000000000000000000000000000000..2022c957fba837904ba2a2efe678edd1686eece8 --- /dev/null +++ b/.github/workflows/doc_check_on_pr.yml @@ -0,0 +1,65 @@ +name: Check Documentation on PR + +on: + pull_request: + paths: + - 'docs/**' + +jobs: + check-i18n: + name: Check docs in diff languages + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - uses: actions/setup-python@v2 + with: + python-version: '3.8.14' + + - run: python .github/workflows/scripts/check_doc_i18n.py -d docs/source + + check-doc-build: + name: Test if the docs can be built + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + path: './ColossalAI' + fetch-depth: 0 + + - uses: actions/checkout@v2 + with: + path: './ColossalAI-Documentation' + repository: 'hpcaitech/ColossalAI-Documentation' + + - uses: actions/setup-python@v2 + with: + python-version: '3.8.14' + + # we use the versions in the main branch as the guide for versions to display + # checkout will give your merged branch + # therefore, we need to make the merged branch as the main branch + - name: Make the merged branch main + run: | + cd ColossalAI + curBranch=$(git rev-parse --abbrev-ref HEAD) + git checkout main + git merge $curBranch # fast-forward master up to the merge + + - name: Build docs + run: | + cache_dir=ColossalAI-Documentation/doc-build/.cache + mkdir $cache_dir + mv ColossalAI $cache_dir + cd ColossalAI-Documentation + pip install -v ./doc-build/third_party/hf-doc-builder + pip install -v ./doc-build + bash ./scripts/build.sh diff --git a/.github/workflows/doc_test_on_pr.yml b/.github/workflows/doc_test_on_pr.yml new file mode 100644 index 0000000000000000000000000000000000000000..fbe669582c2088c17678a763ceea9bfe78738108 --- /dev/null +++ b/.github/workflows/doc_test_on_pr.yml @@ -0,0 +1,94 @@ +name: Test Documentation on PR +on: + pull_request: + # any change in the examples folder will trigger check for the corresponding example. + paths: + - 'docs/source/**.md' + +jobs: + # This is for changed example files detect and output a matrix containing all the corresponding directory name. + detect-changed-doc: + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' + runs-on: ubuntu-latest + outputs: + any_changed: ${{ steps.changed-files.outputs.any_changed }} + changed_files: ${{ steps.changed-files.outputs.all_changed_files }} + name: Detect changed example files + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + ref: ${{ github.event.pull_request.head.sha }} + + - name: Locate base commit + id: locate-base-sha + run: | + curBranch=$(git rev-parse --abbrev-ref HEAD) + commonCommit=$(git merge-base origin/main $curBranch) + echo $commonCommit + echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT + + - name: Get all changed example files + id: changed-files + uses: tj-actions/changed-files@v35 + with: + base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }} + files: | + ./docs/source/**/*.md + + # If no file is changed, it will prompt an error and shows the matrix do not have value. + check-changed-doc: + # Add this condition to avoid executing this job if the trigger event is workflow_dispatch. + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' && + needs.detect-changed-doc.outputs.any_changed == 'true' + name: Test the changed Doc + needs: detect-changed-doc + runs-on: [self-hosted, gpu] + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --gpus all --rm + timeout-minutes: 20 + defaults: + run: + shell: bash + steps: + - name: Checkout ColossalAI-Documentation + uses: actions/checkout@v2 + with: + path: './ColossalAI-Documentation' + repository: 'hpcaitech/ColossalAI-Documentation' + + - name: Install Docer + run: | + pip install -v ./ColossalAI-Documentation/doc-build/third_party/hf-doc-builder + pip install -v ./ColossalAI-Documentation/doc-build + + - name: Checkout ColossalAI + uses: actions/checkout@v3 + + - name: Install Doc Test Requirements + run: | + source activate pytorch + conda env update --file docs/conda-doc-test-deps.yml --prune + pip install -r docs/requirements-doc-test.txt + + - name: Install ColossalAI + run: | + source activate pytorch + pip install -v . + + - name: Test the Doc + run: | + source activate pytorch + for file in ${{ steps.changed-files.outputs.all_changed_files }}; do + echo "Testing $file now..." + docer test -p $file + done + env: + NCCL_SHM_DISABLE: 1 diff --git a/.github/workflows/doc_test_on_schedule.yml b/.github/workflows/doc_test_on_schedule.yml new file mode 100644 index 0000000000000000000000000000000000000000..6b4f5d1f908c608a8ba266725c521369a9fb0047 --- /dev/null +++ b/.github/workflows/doc_test_on_schedule.yml @@ -0,0 +1,47 @@ +name: Test Documentation on Schedule +on: + # run at 07:00 of every Sunday(singapore time) so here is UTC time Saturday 23:00 + schedule: + - cron: '0 23 * * 6' + workflow_dispatch: + +jobs: + check-changed-doc: + # Add this condition to avoid executing this job if the trigger event is workflow_dispatch. + if: github.repository == 'hpcaitech/ColossalAI' + name: Test the changed Doc + runs-on: [self-hosted, gpu] + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --gpus all --rm + timeout-minutes: 60 + steps: + - name: Checkout ColossalAI-Documentation + uses: actions/checkout@v2 + with: + path: './ColossalAI-Documentation' + repository: 'hpcaitech/ColossalAI-Documentation' + + - name: Install Docer + run: | + pip install -v ./ColossalAI-Documentation/doc-build/third_party/hf-doc-builder + pip install -v ./ColossalAI-Documentation/doc-build + + - name: Checkout ColossalAI + uses: actions/checkout@v3 + + - name: Install ColossalAI + run: | + pip install -v . + + - name: Install Doc Test Requirements + run: | + pip install -r docs/requirements-doc-test.txt + + - name: Test the Doc + run: | + for file in $(find ./docs/source -name "*.md"); do + docer test -p $file + done + env: + NCCL_SHM_DISABLE: 1 diff --git a/.github/workflows/draft_github_release_post_after_merge.yml b/.github/workflows/draft_github_release_post_after_merge.yml new file mode 100644 index 0000000000000000000000000000000000000000..53bfa9e8deb66636ad1d04434c4715c6ccfa5dfb --- /dev/null +++ b/.github/workflows/draft_github_release_post_after_merge.yml @@ -0,0 +1,43 @@ +name: Draft GitHub Release Post + +on: + workflow_dispatch: + pull_request: + paths: + - 'version.txt' + types: + - closed + +jobs: + release: + name: Draft Release Post + if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - uses: actions/setup-python@v2 + with: + python-version: '3.8.14' + - name: generate draft + id: generate_draft + run: | + version=v$(cat version.txt) + pip install requests + python ./.github/workflows/scripts/generate_release_draft.py --out $PWD/release_draft.md --version $version + echo "::set-output name=version::$version" + echo "::set-output name=path::$PWD/release_draft.md" + env: + GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Create Release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ steps.generate_draft.outputs.version }} + release_name: Version ${{ steps.generate_draft.outputs.version }} Release Today! + body_path: ${{ steps.generate_draft.outputs.path }} + draft: True + prerelease: false diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml new file mode 100644 index 0000000000000000000000000000000000000000..620d4771af55f53972888013c67347f3f4a392bd --- /dev/null +++ b/.github/workflows/example_check_on_dispatch.yml @@ -0,0 +1,64 @@ +name: Test Example on Dispatch +on: + workflow_dispatch: + inputs: + example_directory: + type: string + description: example directory, separated by space. For example, language/gpt, images/vit. Simply input language or simply gpt does not work. + required: true + +jobs: + matrix_preparation: + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + name: Check the examples user want + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - name: 📚 Checkout + uses: actions/checkout@v3 + - name: Set up matrix + id: set-matrix + env: + check_dir: ${{ inputs.example_directory }} + run: | + res=`python .github/workflows/scripts/example_checks/check_dispatch_inputs.py --fileNameList $check_dir` + if [ res == "failure" ];then + exit -1 + fi + dirs="[${check_dir}]" + echo "Testing examples in $dirs" + echo "matrix={\"directory\":$(echo "$dirs")}" >> $GITHUB_OUTPUT + + test_example: + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + name: Manually check example files + needs: manual_check_matrix_preparation + runs-on: [self-hosted, gpu] + strategy: + fail-fast: false + matrix: ${{fromJson(needs.manual_check_matrix_preparation.outputs.matrix)}} + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --gpus all --rm -v /data/scratch/examples-data:/data/ + timeout-minutes: 10 + steps: + - name: 📚 Checkout + uses: actions/checkout@v3 + - name: Install Colossal-AI + run: | + pip install -v . + - name: Test the example + run: | + dir=${{ matrix.directory }} + echo "Testing ${dir} now" + cd "${PWD}/examples/${dir}" + bash test_ci.sh + env: + NCCL_SHM_DISABLE: 1 diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml new file mode 100644 index 0000000000000000000000000000000000000000..b22664ee47ccbf8ea786b526964694933a51ab5f --- /dev/null +++ b/.github/workflows/example_check_on_pr.yml @@ -0,0 +1,91 @@ +name: Test Example on PR +on: + pull_request: + # any change in the examples folder will trigger check for the corresponding example. + paths: + - 'examples/**' + +jobs: + # This is for changed example files detect and output a matrix containing all the corresponding directory name. + detect-changed-example: + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.setup-matrix.outputs.matrix }} + anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }} + name: Detect changed example files + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + ref: ${{ github.event.pull_request.head.sha }} + + - name: Locate base commit + id: locate-base-sha + run: | + curBranch=$(git rev-parse --abbrev-ref HEAD) + commonCommit=$(git merge-base origin/main $curBranch) + echo $commonCommit + echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT + + - name: Get all changed example files + id: changed-files + uses: tj-actions/changed-files@v35 + with: + base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }} + + - name: setup matrix + id: setup-matrix + run: | + changedFileName="" + for file in ${{ steps.changed-files.outputs.all_changed_files }}; do + changedFileName="${file}:${changedFileName}" + done + echo "$changedFileName was changed" + res=`python .github/workflows/scripts/example_checks/detect_changed_example.py --fileNameList $changedFileName` + echo "All changed examples are $res" + + if [ "$res" == "[]" ]; then + echo "anyChanged=false" >> $GITHUB_OUTPUT + echo "matrix=null" >> $GITHUB_OUTPUT + else + dirs=$( IFS=',' ; echo "${res[*]}" ) + echo "anyChanged=true" >> $GITHUB_OUTPUT + echo "matrix={\"directory\":$(echo "$dirs")}" >> $GITHUB_OUTPUT + fi + + # If no file is changed, it will prompt an error and shows the matrix do not have value. + check-changed-example: + # Add this condition to avoid executing this job if the trigger event is workflow_dispatch. + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' && + needs.detect-changed-example.outputs.anyChanged == 'true' + name: Test the changed example + needs: detect-changed-example + runs-on: [self-hosted, gpu] + strategy: + fail-fast: false + matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}} + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --gpus all --rm -v /data/scratch/examples-data:/data/ + timeout-minutes: 10 + steps: + - uses: actions/checkout@v3 + + - name: Install Colossal-AI + run: | + pip install -v . + + - name: Test the example + run: | + example_dir=${{ matrix.directory }} + cd "${PWD}/examples/${example_dir}" + bash test_ci.sh + env: + NCCL_SHM_DISABLE: 1 diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml new file mode 100644 index 0000000000000000000000000000000000000000..bd52ca4321a2b9abd3182df53cf15b77e40766c7 --- /dev/null +++ b/.github/workflows/example_check_on_schedule.yml @@ -0,0 +1,69 @@ +name: Test Example on Schedule +on: + # run at 00:00 of every Sunday(singapore time) so here is UTC time Saturday 16:00 + schedule: + - cron: '0 16 * * 6' + workflow_dispatch: + +jobs: + # This is for all files' weekly check. Specifically, this job is to find all the directories. + matrix_preparation: + if: github.repository == 'hpcaitech/ColossalAI' + name: Prepare matrix for weekly check + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.setup-matrix.outputs.matrix }} + steps: + - name: 📚 Checkout + uses: actions/checkout@v3 + + - name: setup matrix + id: setup-matrix + run: | + res=`python .github/workflows/scripts/example_checks/check_example_weekly.py` + all_loc=$( IFS=',' ; echo "${res[*]}" ) + echo "Found the examples: $all_loc" + echo "matrix={\"directory\":$(echo "$all_loc")}" >> $GITHUB_OUTPUT + + weekly_check: + if: github.repository == 'hpcaitech/ColossalAI' + name: Weekly check all examples + needs: matrix_preparation + runs-on: [self-hosted, gpu] + strategy: + fail-fast: false + matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + timeout-minutes: 10 + steps: + - name: 📚 Checkout + uses: actions/checkout@v3 + + - name: Install Colossal-AI + run: | + pip install -v . + + - name: Traverse all files + run: | + example_dir=${{ matrix.directory }} + echo "Testing ${example_dir} now" + cd "${PWD}/examples/${example_dir}" + bash test_ci.sh + env: + NCCL_SHM_DISABLE: 1 + + - name: Notify Lark + id: message-preparation + if: ${{ failure() }} + run: | + url=$SERVER_URL/$REPO/actions/runs/$RUN_ID + msg="Example tests failed for $EXAMPLE_DIR, please visit $url for details" + echo $msg + python .github/workflows/scripts/send_message_to_lark.py -m "$msg" -u $WEBHOOK_URL + env: + SERVER_URL: ${{github.server_url }} + REPO: ${{ github.repository }} + RUN_ID: ${{ github.run_id }} + WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} + EXAMPLE_DIR: ${{ matrix.directory }} diff --git a/.github/workflows/post_commit.yml b/.github/workflows/post_commit.yml new file mode 100644 index 0000000000000000000000000000000000000000..1bbc0d2f5c349ab925da923c4b8aa84c2566c292 --- /dev/null +++ b/.github/workflows/post_commit.yml @@ -0,0 +1,97 @@ +name: post-commit + +on: + pull_request: + types: + - closed + +jobs: + # this job will run after a PR is merged to run pre-commit on any changed file + # so that the user does not need to learn pre-commit and pre-commit can still + # be auto-executed by the workflow + pre-commit: + runs-on: ubuntu-latest + if: github.event.pull_request.merged == true && github.repository == 'hpcaitech/ColossalAI' + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + ref: ${{ github.event.pull_request.head.sha }} + + # the PR branch and the hpcaitech/colossal-ai main branch + # must share a common commit, we need to locate that commit, + # which is the commit checked-out or forked when the PR branch is created + # such that we can look for files changed since that commit + - name: Locate base commit + id: locate-base-sha + run: | + curBranch=$(git rev-parse --abbrev-ref HEAD) + commonCommit=$(git merge-base origin/main $curBranch) + echo $commonCommit + echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT + + - name: Find the changed files + id: find-changed-files + uses: tj-actions/changed-files@v35 + with: + base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }} + + - name: List all changed files + run: | + for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do + echo "$file was changed" + done + + # check out the main branch + - uses: actions/checkout@v2 + with: + ref: 'main' + + - uses: actions/setup-python@v3 + + - name: Cache pre-commit hooks + uses: actions/cache@v3 + with: + path: ~/.cache/pre-commit + key: ${{ runner.os }}-pre-commit-hooks + + - name: Set up pre-commit + run: | + pip install pre-commit + pre-commit install + + # run pre-commit on changed files + - name: Run Pre-commit + run: | + for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do + pre-commit run --files $file || true + done + + # create commit for pre-commit + # when all files are well formatted, there is no need to create a commit + # therefore, this step will produce an error, which should be allowed + - name: Create commits + id: commit + continue-on-error: true + run: | + git config --global user.name 'github-actions' + git config --global user.email 'github-actions@github.com' + git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }} + git add -A + git commit -am "[format] applied code formatting on changed files in pull request ${{ github.event.pull_request.number }}" + + # create pull request + - name: Create Pull Request + if: steps.commit.outcome == 'success' + id: cpr + uses: peter-evans/create-pull-request@v4 + with: + branch: pre-commit-${{ github.event.pull_request.number }} + title: "[format] applied code formatting on changed files in PR ${{ github.event.pull_request.number }}" + + - name: Enable Auto-merge for the New PR + if: steps.commit.outcome == 'success' + uses: peter-evans/enable-pull-request-automerge@v2 + with: + pull-request-number: ${{ steps.cpr.outputs.pull-request-number }} + merge-method: squash diff --git a/.github/workflows/release_docker_after_merge.yml b/.github/workflows/release_docker_after_merge.yml new file mode 100644 index 0000000000000000000000000000000000000000..607c19b05472e024d08fdd60dcf95c5516a4420a --- /dev/null +++ b/.github/workflows/release_docker_after_merge.yml @@ -0,0 +1,75 @@ +name: Publish Docker Image to DockerHub after Merge + +on: + workflow_dispatch: + pull_request: + paths: + - 'version.txt' + types: + - closed + +jobs: + release: + name: Publish Docker Image to DockerHub + if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, gpu] + container: + image: "hpcaitech/docker-in-docker:latest" + options: --gpus all --rm -v /var/run/docker.sock:/var/run/docker.sock + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Build Docker + id: build + run: | + version=$(cat version.txt) + tag=hpcaitech/colossalai:$version + docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t $tag ./docker + echo "tag=${tag}" >> $GITHUB_OUTPUT + + - name: Log in to Docker Hub + uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Push Docker image + id: docker-push + run: | + docker push ${{ steps.build.outputs.tag }} + + notify: + name: Notify Lark via webhook + needs: release + runs-on: ubuntu-latest + if: ${{ always() }} + steps: + - uses: actions/checkout@v2 + + - uses: actions/setup-python@v2 + with: + python-version: '3.8.14' + + - name: Install requests + run: pip install requests + + - name: Notify Lark + id: message-preparation + run: | + url=$SERVER_URL/$REPO/actions/runs/$RUN_ID + if [ "$STATUS" == 'success' ] + then + msg="The Docker image for the latest release has been successfully built and pushed to DockerHub." + else + msg="Failed to build and push the Docker image for the latest release, please visit $url for details." + fi + echo $msg + python .github/workflows/scripts/send_message_to_lark.py -m "$msg" -u $WEBHOOK_URL + env: + SERVER_URL: ${{github.server_url }} + REPO: ${{ github.repository }} + RUN_ID: ${{ github.run_id }} + WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} + STATUS: ${{ needs.release.result }} diff --git a/.github/workflows/release_nightly_on_schedule.yml b/.github/workflows/release_nightly_on_schedule.yml new file mode 100644 index 0000000000000000000000000000000000000000..4125f333f30191fd72ff7e16892d65bf5a6748eb --- /dev/null +++ b/.github/workflows/release_nightly_on_schedule.yml @@ -0,0 +1,65 @@ +name: Publish Nightly Version to PyPI + +on: + workflow_dispatch: + schedule: + - cron: '0 0 * * 6' # release on every Sunday 00:00 UTC time + +jobs: + build-n-publish: + if: github.repository == 'hpcaitech/ColossalAI' + name: Build and publish Python 🐍 distributions 📦 to PyPI + runs-on: ubuntu-latest + timeout-minutes: 20 + steps: + - uses: actions/checkout@v2 + + - uses: actions/setup-python@v2 + with: + python-version: '3.8.14' + + - run: NIGHTLY=1 python setup.py sdist build + + # publish to PyPI if executed on the main branch + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + id: publish + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} + verbose: true + + notify: + name: Notify Lark via webhook + needs: build-n-publish + runs-on: ubuntu-latest + if: ${{ always() }} && github.repository == 'hpcaitech/ColossalAI' + steps: + - uses: actions/checkout@v2 + + - uses: actions/setup-python@v2 + with: + python-version: '3.8.14' + + - name: Install requests + run: pip install requests + + - name: Notify Lark + id: message-preparation + run: | + url=$SERVER_URL/$REPO/actions/runs/$RUN_ID + + if [ $STATUS == 'success' ] + then + msg="The Colossal-AI nightly version has been successfully released to PyPI." + else + msg="Failed to release Colossal-AI nightly version to PyPI, please visit $url for details." + fi + echo $msg + python .github/workflows/scripts/send_message_to_lark.py -m "$msg" -u $WEBHOOK_URL + env: + SERVER_URL: ${{github.server_url }} + REPO: ${{ github.repository }} + RUN_ID: ${{ github.run_id }} + WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} + STATUS: ${{ steps.publish.outcome }} diff --git a/.github/workflows/release_pypi_after_merge.yml b/.github/workflows/release_pypi_after_merge.yml new file mode 100644 index 0000000000000000000000000000000000000000..b987b4397c173ee667fc32093525d6ca0fdf7f14 --- /dev/null +++ b/.github/workflows/release_pypi_after_merge.yml @@ -0,0 +1,68 @@ +name: Publish to PyPI + +on: + workflow_dispatch: + pull_request: + paths: + - 'version.txt' + types: + - closed + +jobs: + build-n-publish: + if: github.event_name == 'workflow_dispatch' || github.repository == 'hpcaitech/ColossalAI' && github.event.pull_request.merged == true && github.base_ref == 'main' + name: Build and publish Python 🐍 distributions 📦 to PyPI + runs-on: ubuntu-latest + timeout-minutes: 20 + steps: + - uses: actions/checkout@v2 + + - uses: actions/setup-python@v2 + with: + python-version: '3.8.14' + + - run: python setup.py sdist build + + # publish to PyPI if executed on the main branch + - name: Publish package to PyPI + id: publish + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} + verbose: true + + notify: + name: Notify Lark via webhook + needs: build-n-publish + runs-on: ubuntu-latest + if: ${{ always() }} + steps: + - uses: actions/checkout@v2 + + - uses: actions/setup-python@v2 + with: + python-version: '3.8.14' + + - name: Install requests + run: pip install requests + + - name: Notify Lark + id: message-preparation + run: | + url=$SERVER_URL/$REPO/actions/runs/$RUN_ID + + if [ "$STATUS" == 'success' ] + then + msg="The Colossal-AI latest version has been successfully released to PyPI." + else + msg="Failed to release Colossal-AI to PyPI, please visit $url for details." + fi + echo $msg + python .github/workflows/scripts/send_message_to_lark.py -m "$msg" -u $WEBHOOK_URL + env: + SERVER_URL: ${{github.server_url }} + REPO: ${{ github.repository }} + RUN_ID: ${{ github.run_id }} + WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} + STATUS: ${{ needs.build-n-publish.result }} diff --git a/.github/workflows/release_test_pypi_before_merge.yml b/.github/workflows/release_test_pypi_before_merge.yml new file mode 100644 index 0000000000000000000000000000000000000000..49c6262651752812cbd31b037896dc97260f81ad --- /dev/null +++ b/.github/workflows/release_test_pypi_before_merge.yml @@ -0,0 +1,52 @@ +name: Publish to Test-PyPI Before Merge + +on: + pull_request: + paths: + - 'version.txt' + +jobs: + build-n-publish: + if: github.event_name == 'workflow_dispatch' || github.repository == 'hpcaitech/ColossalAI' + name: Build and publish Python 🐍 distributions 📦 to Test PyPI + runs-on: ubuntu-latest + timeout-minutes: 20 + steps: + - uses: actions/checkout@v2 + + - uses: actions/setup-python@v2 + with: + python-version: '3.8.14' + + - name: add timestamp to the version + id: prep-version + run: | + version=$(cat version.txt) + timestamp=$(date +%s) + new_version="${version}.post${timestamp}" + echo $new_version > ./version.txt + echo "version=$new_version" >> $GITHUB_OUTPUT + + - run: python setup.py sdist build + + # publish to PyPI if executed on the main branch + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.TEST_PYPI_API_TOKEN }} + repository_url: https://test.pypi.org/legacy/ + verbose: true + + - name: Wait for Test-PyPI refresh + run: sleep 300s + shell: bash + + - name: Try installation + run: | + # we need to install the requirements.txt first + # as test-pypi may not contain the distributions for libs listed in the txt file + pip install -r requirements/requirements.txt + pip install --index-url https://test.pypi.org/simple/ colossalai==$VERSION + env: + VERSION: ${{ steps.prep-version.outputs.version }} diff --git a/.github/workflows/report_leaderboard_to_lark.yml b/.github/workflows/report_leaderboard_to_lark.yml new file mode 100644 index 0000000000000000000000000000000000000000..00d8e9e1f5fd6683b471326bccc4b58477e057b2 --- /dev/null +++ b/.github/workflows/report_leaderboard_to_lark.yml @@ -0,0 +1,29 @@ +name: Generate Community Report and Send to Lark + +on: + workflow_dispatch: + schedule: + # release on every Friday 09:00 UTC time, 17:00 Beijing/Singapore time + - cron: '0 9 * * 5' + +jobs: + generate-and-publish: + if: github.repository == 'hpcaitech/ColossalAI' + name: Generate leaderboard report and publish to Lark + runs-on: ubuntu-latest + timeout-minutes: 20 + steps: + - uses: actions/checkout@v2 + + - uses: actions/setup-python@v2 + with: + python-version: '3.8.14' + + - run: pip install requests matplotlib seaborn requests_toolbelt pytz + + - run: python .github/workflows/scripts/generate_leaderboard_and_send_to_lark.py + env: + LARK_APP_ID: ${{ secrets.LARK_LEADERBOARD_APP_ID }} + LARK_APP_SECRET: ${{ secrets.LARK_LEADERBOARD_APP_SECRET }} + LARK_WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} + GITHUB_TOKEN: ${{ github.token }} diff --git a/.github/workflows/report_test_coverage.yml b/.github/workflows/report_test_coverage.yml new file mode 100644 index 0000000000000000000000000000000000000000..bbada74e685025d15b57f6a4cae855046f99e8ec --- /dev/null +++ b/.github/workflows/report_test_coverage.yml @@ -0,0 +1,80 @@ +name: Report Test Coverage + +on: + workflow_run: + workflows: [Build on PR] + types: + - completed + +jobs: + report-test-coverage: + runs-on: ubuntu-latest + steps: + - name: 'Download artifact' + uses: actions/github-script@v6 + with: + script: | + let allArtifacts = await github.rest.actions.listWorkflowRunArtifacts({ + owner: context.repo.owner, + repo: context.repo.repo, + run_id: context.payload.workflow_run.id, + }); + let matchArtifact = allArtifacts.data.artifacts.filter((artifact) => { + return artifact.name == "report" + })[0]; + let download = await github.rest.actions.downloadArtifact({ + owner: context.repo.owner, + repo: context.repo.repo, + artifact_id: matchArtifact.id, + archive_format: 'zip', + }); + let fs = require('fs'); + fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/report.zip`, Buffer.from(download.data)); + + - name: 'Unzip artifact' + id: unzip + run: | + unzip report.zip + if [ -f "coverage.txt" ]; then + echo "hasReport=true" >> $GITHUB_OUTPUT + else + echo "hasReport=false" >> $GITHUB_OUTPUT + fi + + - name: Make Coverage Report Collapsable + if: steps.unzip.outputs.hasReport == 'true' + run: | + covNum=$(cat cov_number) + title="The code coverage for the changed files is ${covNum}%." + touch coverage_report.txt + echo $title >> coverage_report.txt + echo " " >> coverage_report.txt + echo "
" >> coverage_report.txt + echo "Click me to view the complete report" >> coverage_report.txt + echo " " >> coverage_report.txt + echo "\`\`\`" >> coverage_report.txt + cat coverage.txt >> coverage_report.txt + echo "\`\`\`" >> coverage_report.txt + echo "
" >> coverage_report.txt + mv coverage_report.txt coverage.txt + + - name: 'Comment on PR' + if: steps.unzip.outputs.hasReport == 'true' + uses: actions/github-script@v6 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + let fs = require('fs'); + let issue_number = Number(fs.readFileSync('./pr_number')); + let owner = context.repo.owner; + let repo = context.repo.repo; + let run_id = context.payload.workflow_run.id; + let run_url = `https://github.com/${owner}/${repo}/actions/runs/${run_id}` + let body = fs.readFileSync('./coverage.txt', {encoding:'utf8', flag:'r'}) + + await github.rest.issues.createComment({ + owner: owner, + repo: repo, + issue_number: issue_number, + body: body + }); diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml new file mode 100644 index 0000000000000000000000000000000000000000..9d9d3a007851276cbf5dbd69e5c8b776178e5de4 --- /dev/null +++ b/.github/workflows/run_chatgpt_examples.yml @@ -0,0 +1,56 @@ +name: Run ChatGPT examples + +on: + pull_request: + types: [synchronize, opened, reopened] + paths: + - 'applications/Chat/coati/**' + - 'applications/Chat/requirements.txt' + - 'applications/Chat/setup.py' + - 'applications/Chat/examples/**' + + +jobs: + tests: + name: Run ChatGPT examples + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, gpu] + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat + timeout-minutes: 30 + defaults: + run: + shell: bash + steps: + - name: Checkout ColossalAI + uses: actions/checkout@v2 + + - name: Install ColossalAI and ChatGPT + run: | + pip install -e . + cd applications/Chat + pip install -v . + pip install -r examples/requirements.txt + + - name: Install Transformers + run: | + cd applications/Chat + git clone https://github.com/hpcaitech/transformers + cd transformers + pip install -v . + + - name: Execute Examples + run: | + cd applications/Chat + rm -rf ~/.cache/colossalai + ./examples/test_ci.sh + env: + NCCL_SHM_DISABLE: 1 + MAX_JOBS: 8 + SFT_DATASET: /data/scratch/github_actions/chat/data.json + PROMPT_PATH: /data/scratch/github_actions/chat/prompts_en.jsonl + PRETRAIN_DATASET: /data/scratch/github_actions/chat/alpaca_data.json diff --git a/.github/workflows/run_chatgpt_unit_tests.yml b/.github/workflows/run_chatgpt_unit_tests.yml new file mode 100644 index 0000000000000000000000000000000000000000..47c80fc9a9fecafa332a0cb0e457052b8d929711 --- /dev/null +++ b/.github/workflows/run_chatgpt_unit_tests.yml @@ -0,0 +1,47 @@ +name: Run ChatGPT unit tests + +on: + pull_request: + types: [synchronize, opened, reopened] + paths: + - 'applications/Chat/coati/**' + - 'applications/Chat/requirements.txt' + - 'applications/Chat/setup.py' + - 'applications/Chat/requirements-test.txt' + - 'applications/Chat/tests/**' + - 'applications/Chat/pytest.ini' + +jobs: + tests: + name: Run ChatGPT unit tests + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, gpu] + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --gpus all --rm -v /data/scratch/chatgpt:/data/scratch/chatgpt + timeout-minutes: 30 + defaults: + run: + shell: bash + steps: + - name: Checkout ColossalAI + uses: actions/checkout@v2 + + - name: Install ColossalAI and ChatGPT + run: | + pip install -e . + cd applications/Chat + pip install -v . + pip install -r requirements-test.txt + + - name: Execute Unit Testing + run: | + cd applications/Chat + rm -rf ~/.cache/colossalai + pytest tests/ + env: + NCCL_SHM_DISABLE: 1 + MAX_JOBS: 8 diff --git a/.github/workflows/scripts/check_doc_i18n.py b/.github/workflows/scripts/check_doc_i18n.py new file mode 100644 index 0000000000000000000000000000000000000000..1aa7283e9e52f169d89f337d7942cf55f601257d --- /dev/null +++ b/.github/workflows/scripts/check_doc_i18n.py @@ -0,0 +1,67 @@ +import argparse +import os + + +def compare_dirs(dir1, dir2): + # First, we need to check if the two directories exist + if not os.path.exists(dir1) or not os.path.exists(dir2): + return False + + # Now, we compare the list of items in each directory + items1 = os.listdir(dir1) + items2 = os.listdir(dir2) + + # If the number of items in each directory is different, the directories are different + if len(items1) != len(items2): + return False + + # For each item in the first directory, we check if there is a corresponding item in the second directory + for item in items1: + item_path1 = os.path.join(dir1, item) + item_path2 = os.path.join(dir2, item) + + # If the corresponding item doesn't exist in the second directory, the directories are different + if not os.path.exists(item_path2): + print(f'Found mismatch: {item_path1}, {item_path2}') + return False + + # If the corresponding item is a directory, we compare the two directories recursively + if os.path.isdir(item_path1) and os.path.isdir(item_path2): + if not compare_dirs(item_path1, item_path2): + print(f'Found mismatch: {item_path1}, {item_path2}') + return False + + # both are files + elif os.path.isfile(item_path1) and os.path.isfile(item_path2): + continue + + # If the corresponding item is not a file or a directory, the directories are different + else: + print(f'Found mismatch: {item_path1}, {item_path2}') + return False + + # If all items are the same, the directories are the same + return True + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-d', '--directory', help="The directory where the multi-language source files are kept.") + args = parser.parse_args() + + i18n_folders = os.listdir(args.directory) + i18n_folders = [os.path.join(args.directory, val) for val in i18n_folders] + + if len(i18n_folders) > 1: + for i in range(1, len(i18n_folders)): + dir1 = i18n_folders[0] + dir2 = i18n_folders[i] + print(f'comparing {dir1} vs {dir2}') + match = compare_dirs(i18n_folders[0], i18n_folders[i]) + + if not match: + print( + f"{dir1} and {dir2} don't match, please ensure that your documentation is available in different languages" + ) + else: + print(f"{dir1} and {dir2} match") diff --git a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..5bec96187e0cc5b0aa5ebd8e6a59f73ac8b6d88d --- /dev/null +++ b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py @@ -0,0 +1,27 @@ +import argparse +import os + + +def check_inputs(input_list): + for path in input_list: + real_path = os.path.join('examples', path) + if not os.path.exists(real_path): + return False + return True + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-f', '--fileNameList', type=str, help="List of file names") + args = parser.parse_args() + name_list = args.fileNameList.split(",") + is_correct = check_inputs(name_list) + + if is_correct: + print('success') + else: + print('failure') + + +if __name__ == '__main__': + main() diff --git a/.github/workflows/scripts/example_checks/check_example_weekly.py b/.github/workflows/scripts/example_checks/check_example_weekly.py new file mode 100644 index 0000000000000000000000000000000000000000..83eff644e3150dae8fa7ada808dd1e16b571e54a --- /dev/null +++ b/.github/workflows/scripts/example_checks/check_example_weekly.py @@ -0,0 +1,37 @@ +import os + + +def show_files(path, all_files): + # Traverse all the folder/file in current directory + file_list = os.listdir(path) + # Determine the element is folder or file. If file, pass it into list, if folder, recurse. + for file_name in file_list: + # Get the abs directory using os.path.join() and store into cur_path. + cur_path = os.path.join(path, file_name) + # Determine whether folder + if os.path.isdir(cur_path): + show_files(cur_path, all_files) + else: + all_files.append(cur_path) + return all_files + + +def join(input_list, sep=None): + return (sep or ' ').join(input_list) + + +def main(): + contents = show_files('examples/', []) + all_loc = [] + for file_loc in contents: + split_loc = file_loc.split('/') + # must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not. + if len(split_loc) >= 4: + re_loc = '/'.join(split_loc[1:3]) + if re_loc not in all_loc: + all_loc.append(re_loc) + print(all_loc) + + +if __name__ == '__main__': + main() diff --git a/.github/workflows/scripts/example_checks/detect_changed_example.py b/.github/workflows/scripts/example_checks/detect_changed_example.py new file mode 100644 index 0000000000000000000000000000000000000000..c69d95a552e96bfe425fc7ad06c0de6a30b1d786 --- /dev/null +++ b/.github/workflows/scripts/example_checks/detect_changed_example.py @@ -0,0 +1,24 @@ +import argparse + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files") + args = parser.parse_args() + name_list = args.fileNameList.split(":") + folder_need_check = set() + for loc in name_list: + # Find only the sub-sub-folder of 'example' folder + # the examples folder structure is like + # - examples + # - area + # - application + # - file + if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4: + folder_need_check.add('/'.join(loc.split("/")[1:3])) + # Output the result using print. Then the shell can get the values. + print(list(folder_need_check)) + + +if __name__ == '__main__': + main() diff --git a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py new file mode 100644 index 0000000000000000000000000000000000000000..16b8957c1d884aba897fd4734a8ec54efc299c5a --- /dev/null +++ b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py @@ -0,0 +1,476 @@ +import os +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any, Dict, List + +import matplotlib.pyplot as plt +import pytz +import requests +import seaborn +from requests_toolbelt import MultipartEncoder + + +@dataclass +class Contributor: + """ + Dataclass for a github contributor. + + Args: + name (str): name of the contributor + num_commits_this_week (int): number of commits made within one week + """ + name: str + num_commits_this_week: int + + +def plot_bar_chart(x: List[Any], y: List[Any], xlabel: str, ylabel: str, title: str, output_path: str) -> None: + """ + This function is a utility to plot the bar charts. + """ + plt.clf() + seaborn.color_palette() + fig = seaborn.barplot(x=x, y=y) + fig.set(xlabel=xlabel, ylabel=ylabel, title=title) + seaborn.despine() + plt.tight_layout() + plt.savefig(output_path, dpi=1200) + + +def get_issue_pull_request_comments(github_token: str, since: str) -> Dict[str, int]: + """ + Retrive the issue/PR comments made by our members in the last 7 days. + + Args: + github_token (str): GitHub access token for API calls + since (str): the path parameter required by GitHub Restful APIs, in the format of YYYY-MM-DDTHH:MM:SSZ + """ + # prepare header + headers = { + 'Authorization': f'Bearer {github_token}', + 'Accept': 'application/vnd.github+json', + 'X-GitHub-Api-Version': '2022-11-28' + } + + user_engagement_count = {} + + # do pagination to the API + page = 1 + while True: + comment_api = f'https://api.github.com/repos/hpcaitech/ColossalAI/issues/comments?since={since}&page={page}' + comment_response = requests.get(comment_api, headers=headers).json() + + if len(comment_response) == 0: + break + else: + for item in comment_response: + comment_author_relationship = item['author_association'] + if comment_author_relationship != 'MEMBER': + # if the comment is not made by our member + # we don't count this comment towards user engagement + continue + + issue_id = item['issue_url'].split('/')[-1] + issue_api = f'https://api.github.com/repos/hpcaitech/ColossalAI/issues/{issue_id}' + issue_response = requests.get(issue_api, headers=headers).json() + issue_author_relationship = issue_response['author_association'] + + if issue_author_relationship != 'MEMBER': + # this means that the issue/PR is not created by our own people + # any comments in this issue/PR by our member will be counted towards the leaderboard + member_name = item['user']['login'] + + if member_name in user_engagement_count: + user_engagement_count[member_name] += 1 + else: + user_engagement_count[member_name] = 1 + page += 1 + return user_engagement_count + + +def get_discussion_comments(github_token, since) -> Dict[str, int]: + """ + Retrive the discussion comments made by our members in the last 7 days. + This is only available via the GitHub GraphQL API. + + Args: + github_token (str): GitHub access token for API calls + since (Datetime): the query parameter to determine whether the comment is made this week + """ + + # use graphql to get the discussions updated in the last 7 days + def _generate_discussion_query(num, cursor: str = None): + if cursor is None: + offset_str = "" + else: + offset_str = f", after: \"{cursor}\"" + query = f""" + {{ + repository(owner: "hpcaitech", name: "ColossalAI"){{ + discussions(first: {num} {offset_str}){{ + edges {{ + cursor + node{{ + title + author{{ + login + }} + number + authorAssociation + updatedAt + }} + }} + }} + }} + }} + """ + return query + + def _generate_comment_reply_count_for_discussion(discussion_number, num, cursor: str = None): + # here we assume that each comment will not have more than 100 replies for simplicity + # otherwise, we have to go through pagination for both comment and reply + if cursor is None: + offset_str = "" + else: + offset_str = f", before: \"{cursor}\"" + query = f""" + {{ + repository(owner: "hpcaitech", name: "ColossalAI"){{ + discussion(number: {discussion_number}){{ + title + comments(last: {num} {offset_str}){{ + edges{{ + cursor + node {{ + author{{ + login + }} + updatedAt + authorAssociation + replies (last: 100) {{ + edges {{ + node {{ + author {{ + login + }} + updatedAt + authorAssociation + }} + }} + }} + }} + }} + }} + }} + }} + }} + """ + return query + + # a utility function to make call to Github GraphQL API + def _call_graphql_api(query): + headers = {"Authorization": f"Bearer {github_token}"} + json_data = {'query': query} + response = requests.post('https://api.github.com/graphql', json=json_data, headers=headers) + data = response.json() + return data + + # get the discussion numbers updated in the last 7 days + discussion_numbers = [] + num_per_request = 10 + cursor = None + while True: + query = _generate_discussion_query(num_per_request, cursor) + data = _call_graphql_api(query) + found_discussion_out_of_time_range = False + + edges = data['data']['repository']['discussions']['edges'] + if len(edges) == 0: + break + else: + # keep the discussion whose author is not a member + for edge in edges: + # print the discussion title + discussion = edge['node'] + + discussion_updated_at = datetime.strptime(discussion['updatedAt'], "%Y-%m-%dT%H:%M:%SZ") + # check if the updatedAt is within the last 7 days + # if yes, add it to dicussion_numbers + if discussion_updated_at > since: + if discussion['authorAssociation'] != 'MEMBER': + discussion_numbers.append(discussion['number']) + else: + found_discussion_out_of_time_range = True + + if found_discussion_out_of_time_range: + break + else: + # update cursor + cursor = edges[-1]['cursor'] + + # get the dicussion comments and replies made by our member + user_engagement_count = {} + for dicussion_number in discussion_numbers: + cursor = None + num_per_request = 10 + + while True: + query = _generate_comment_reply_count_for_discussion(dicussion_number, num_per_request, cursor) + data = _call_graphql_api(query) + + # get the comments + edges = data['data']['repository']['discussion']['comments']['edges'] + + # update the cursor + if len(edges) == 0: + break + else: + # update cursor for pagination + cursor = edges[-1]['cursor'] + + for edge in edges: + comment = edge['node'] + if comment['authorAssociation'] == 'MEMBER': + # check if the updatedAt is within the last 7 days + # if yes, add it to user_engagement_count + comment_updated_at = datetime.strptime(comment['updatedAt'], "%Y-%m-%dT%H:%M:%SZ") + if comment_updated_at > since: + member_name = comment['author']['login'] + if member_name in user_engagement_count: + user_engagement_count[member_name] += 1 + else: + user_engagement_count[member_name] = 1 + + # get the replies + reply_edges = comment['replies']['edges'] + if len(reply_edges) == 0: + continue + else: + for reply_edge in reply_edges: + reply = reply_edge['node'] + if reply['authorAssociation'] == 'MEMBER': + # check if the updatedAt is within the last 7 days + # if yes, add it to dicussion_numbers + reply_updated_at = datetime.strptime(reply['updatedAt'], "%Y-%m-%dT%H:%M:%SZ") + if reply_updated_at > since: + member_name = reply['author']['login'] + if member_name in user_engagement_count: + user_engagement_count[member_name] += 1 + else: + user_engagement_count[member_name] = 1 + return user_engagement_count + + +def generate_user_engagement_leaderboard_image(github_token: str, output_path: str) -> bool: + """ + Generate the user engagement leaderboard image for stats within the last 7 days + + Args: + github_token (str): GitHub access token for API calls + output_path (str): the path to save the image + """ + + # request to the Github API to get the users who have replied the most in the last 7 days + now = datetime.utcnow() + start_datetime = now - timedelta(days=7) + start_datetime_str = start_datetime.strftime("%Y-%m-%dT%H:%M:%SZ") + + # get the issue/PR comments and discussion comment count + issue_pr_engagement_count = get_issue_pull_request_comments(github_token=github_token, since=start_datetime_str) + discussion_engagement_count = get_discussion_comments(github_token=github_token, since=start_datetime) + total_engagement_count = {} + + # update the total engagement count + total_engagement_count.update(issue_pr_engagement_count) + for name, count in discussion_engagement_count.items(): + if name in total_engagement_count: + total_engagement_count[name] += count + else: + total_engagement_count[name] = count + + # prepare the data for plotting + x = [] + y = [] + + if len(total_engagement_count) > 0: + ranking = [] + for name, count in total_engagement_count.items(): + ranking.append((name, count)) + + ranking.sort(key=lambda x: x[1], reverse=True) + + for name, count in ranking: + x.append(count) + y.append(name) + + # use Shanghai time to display on the image + start_datetime_str = datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y-%m-%dT%H:%M:%SZ") + + # plot the leaderboard + xlabel = f"Number of Comments made (since {start_datetime_str})" + ylabel = "Member" + title = 'Active User Engagement Leaderboard' + plot_bar_chart(x, y, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path) + return True + else: + return False + + +def generate_contributor_leaderboard_image(github_token, output_path) -> bool: + """ + Generate the contributor leaderboard image for stats within the last 7 days + + Args: + github_token (str): GitHub access token for API calls + output_path (str): the path to save the image + """ + # request to the Github API to get the users who have contributed in the last 7 days + URL = 'https://api.github.com/repos/hpcaitech/ColossalAI/stats/contributors' + headers = { + 'Authorization': f'Bearer {github_token}', + 'Accept': 'application/vnd.github+json', + 'X-GitHub-Api-Version': '2022-11-28' + } + + while True: + response = requests.get(URL, headers=headers).json() + + if len(response) != 0: + # sometimes the Github API returns empty response for unknown reason + # request again if the response is empty + break + + contributor_list = [] + + # get number of commits for each contributor + start_timestamp = None + for item in response: + num_commits_this_week = item['weeks'][-1]['c'] + name = item['author']['login'] + contributor = Contributor(name=name, num_commits_this_week=num_commits_this_week) + contributor_list.append(contributor) + + # update start_timestamp + start_timestamp = item['weeks'][-1]['w'] + + # convert unix timestamp to Beijing datetime + start_datetime = datetime.fromtimestamp(start_timestamp, tz=pytz.timezone('Asia/Shanghai')) + start_datetime_str = start_datetime.strftime("%Y-%m-%dT%H:%M:%SZ") + + # sort by number of commits + contributor_list.sort(key=lambda x: x.num_commits_this_week, reverse=True) + + # remove contributors who has zero commits + contributor_list = [x for x in contributor_list if x.num_commits_this_week > 0] + + # prepare the data for plotting + x = [x.num_commits_this_week for x in contributor_list] + y = [x.name for x in contributor_list] + + # plot + if len(x) > 0: + xlabel = f"Number of Commits (since {start_datetime_str})" + ylabel = "Contributor" + title = 'Active Contributor Leaderboard' + plot_bar_chart(x, y, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path) + return True + else: + return False + + +def upload_image_to_lark(lark_tenant_token: str, image_path: str) -> str: + """ + Upload image to Lark and return the image key + + Args: + lark_tenant_token (str): Lark tenant access token + image_path (str): the path to the image to be uploaded + """ + url = "https://open.feishu.cn/open-apis/im/v1/images" + form = {'image_type': 'message', 'image': (open(image_path, 'rb'))} # 需要替换具体的path + multi_form = MultipartEncoder(form) + headers = { + 'Authorization': f'Bearer {lark_tenant_token}', ## 获取tenant_access_token, 需要替换为实际的token + } + headers['Content-Type'] = multi_form.content_type + response = requests.request("POST", url, headers=headers, data=multi_form).json() + return response['data']['image_key'] + + +def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str: + """ + Generate Lark tenant access token. + + Args: + app_id (str): Lark app id + app_secret (str): Lark app secret + """ + url = 'https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal' + data = {'app_id': app_id, 'app_secret': app_secret} + response = requests.post(url, json=data).json() + return response['tenant_access_token'] + + +def send_image_to_lark(image_key: str, webhook_url: str) -> None: + """ + Send image to Lark. + + Args: + image_key (str): the image key returned by Lark + webhook_url (str): the webhook url to send the image + """ + data = {"msg_type": "image", "content": {"image_key": image_key}} + requests.post(webhook_url, json=data) + + +def send_message_to_lark(message: str, webhook_url: str): + """ + Send message to Lark. + + Args: + message (str): the message to be sent + webhook_url (str): the webhook url to send the message + """ + data = {"msg_type": "text", "content": {"text": message}} + requests.post(webhook_url, json=data) + + +if __name__ == '__main__': + GITHUB_TOKEN = os.environ['GITHUB_TOKEN'] + CONTRIBUTOR_IMAGE_PATH = 'contributor_leaderboard.png' + USER_ENGAGEMENT_IMAGE_PATH = 'engagement_leaderboard.png' + + # generate images + contrib_success = generate_contributor_leaderboard_image(GITHUB_TOKEN, CONTRIBUTOR_IMAGE_PATH) + engagement_success = generate_user_engagement_leaderboard_image(GITHUB_TOKEN, USER_ENGAGEMENT_IMAGE_PATH) + + # upload images + APP_ID = os.environ['LARK_APP_ID'] + APP_SECRET = os.environ['LARK_APP_SECRET'] + LARK_TENANT_TOKEN = generate_lark_tenant_access_token(app_id=APP_ID, app_secret=APP_SECRET) + contributor_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, CONTRIBUTOR_IMAGE_PATH) + user_engagement_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, USER_ENGAGEMENT_IMAGE_PATH) + + # send message to lark + LARK_WEBHOOK_URL = os.environ['LARK_WEBHOOK_URL'] + message = """本周的社区榜单出炉啦! +1. 开发贡献者榜单 +2. 用户互动榜单 + +注: +- 开发贡献者测评标准为:本周由公司成员提交的commit次数 +- 用户互动榜单测评标准为:本周由公司成员在非成员创建的issue/PR/discussion中回复的次数 +""" + + send_message_to_lark(message, LARK_WEBHOOK_URL) + + # send contributor image to lark + if contrib_success: + send_image_to_lark(contributor_image_key, LARK_WEBHOOK_URL) + else: + send_message_to_lark("本周没有成员贡献commit,无榜单图片生成。", LARK_WEBHOOK_URL) + + # send user engagement image to lark + if engagement_success: + send_image_to_lark(user_engagement_image_key, LARK_WEBHOOK_URL) + else: + send_message_to_lark("本周没有成员互动,无榜单图片生成。", LARK_WEBHOOK_URL) diff --git a/.github/workflows/scripts/generate_release_draft.py b/.github/workflows/scripts/generate_release_draft.py new file mode 100644 index 0000000000000000000000000000000000000000..dc592e4c977b46b0f0156a9da08832b7624776df --- /dev/null +++ b/.github/workflows/scripts/generate_release_draft.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +# coding: utf-8 + +import argparse +import os +import re + +import requests + +COMMIT_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/commits' +TAGS_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/tags' + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--out', type=str, help='output path for the release draft', required=True) + parser.add_argument('--version', type=str, help='current version to release', required=True) + return parser.parse_args() + + +def get_latest_tag_commit(headers=None): + res = requests.get(url=TAGS_API, headers=headers) + data = res.json() + commit_hash = data[0]['commit']['sha'] + version = data[0]['name'] + return commit_hash, version + + +def get_commit_info(commit_hash, headers=None): + api = f'{COMMIT_API}/{commit_hash}' + res = requests.get(url=api, headers=headers) + return res.json() + + +def get_all_commit_info(since, headers=None): + page = 1 + results = [] + + while True: + api = f'{COMMIT_API}?since={since}&per_page=100&page={page}' + resp = requests.get(url=api, headers=headers) + data = resp.json() + + # exit when no more data + if len(data) == 0: + break + + results.extend(data) + page += 1 + + return results + + +def collate_release_info(commit_info_list): + results = dict() + pattern = pattern = r'\[.*\]' + + for commit_info in commit_info_list: + author = commit_info['commit']['author']['name'] + + try: + author_url = commit_info['author']['url'] + except: + # author can be None + author_url = None + msg = commit_info['commit']['message'] + match = re.search(pattern, msg) + + if match: + tag = match.group().lstrip('[').rstrip(']').capitalize() + if tag not in results: + results[tag] = [] + results[tag].append((msg, author, author_url)) + + return results + + +def generate_release_post_markdown(current_version, last_version, release_info): + text = [] + + # add highlights + highlights = "## What's Changed \n\n" + text.append(highlights) + + # add items + for k, v in release_info.items(): + topic = f"### {k} \n" + text.append(topic) + + for msg, author, author_url in v: + # only keep the first line + msg = msg.split('\n')[0] + + if author_url: + item = f'{msg} by [{author}]({author_url})\n' + else: + item = f'{msg} by {author}\n' + text.append(f'- {item}') + + text.append('\n') + + # add full change log + text.append( + f'**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}') + + return text + + +if __name__ == '__main__': + args = parse_args() + token = os.environ['GITHUB_API_TOKEN'] + headers = {'Authorization': token} + + # get previous release tag + last_release_commit, last_version = get_latest_tag_commit(headers) + last_release_commit_info = get_commit_info(last_release_commit, headers=headers) + last_release_date = last_release_commit_info['commit']['author']['date'] + + # get the commits since last release + commit_info = get_all_commit_info(since=last_release_date, headers=headers) + commit_info = commit_info[:-1] # remove the release commit + + # collate into markdown + release_info = collate_release_info(commit_info) + markdown_text = generate_release_post_markdown(args.version, last_version, release_info) + + # write into a file + with open(args.out, 'w') as f: + for line in markdown_text: + f.write(line) diff --git a/.github/workflows/scripts/send_message_to_lark.py b/.github/workflows/scripts/send_message_to_lark.py new file mode 100644 index 0000000000000000000000000000000000000000..a113327a786ed1310b6ef8c0ffc784bd6af2e344 --- /dev/null +++ b/.github/workflows/scripts/send_message_to_lark.py @@ -0,0 +1,20 @@ +import argparse + +import requests + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('-m', '--message', type=str) + parser.add_argument('-u', '--url', type=str) + return parser.parse_args() + + +def send_message_to_lark(message, webhook_url): + data = {"msg_type": "text", "content": {"text": message}} + requests.post(webhook_url, json=data) + + +if __name__ == '__main__': + args = parse_args() + send_message_to_lark(args.message, args.url) diff --git a/.github/workflows/submodule.yml b/.github/workflows/submodule.yml new file mode 100644 index 0000000000000000000000000000000000000000..4ffb261183f121a87eb39dc710cc9838dac11bbf --- /dev/null +++ b/.github/workflows/submodule.yml @@ -0,0 +1,45 @@ +name: Synchronize Submodule + +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" + +jobs: + sync-submodule: + runs-on: ubuntu-latest + if: github.repository == 'hpcaitech/ColossalAI' + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + ref: 'main' + submodules: true + + - name: echo + run: | + echo ${{github}} + + - name: Git Sumbodule Update + run: | + git pull --recurse-submodules + git submodule update --remote --recursive + + - name: Commit update + run: | + git config --global user.name 'github-actions' + git config --global user.email 'github-actions@github.com' + git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }} + git commit -am "Automated submodule synchronization" + + - name: Create Pull Request + uses: peter-evans/create-pull-request@v3 + with: + title: '[Bot] Synchronize Submodule References' + body: | + Automated PR to update submodule commits + committer: GitHub + author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> + assignees: ${{ github.actor }} + delete-branch: true + branch: create-pull-request/patch-sync-submodule diff --git a/.github/workflows/translate_comment.yml b/.github/workflows/translate_comment.yml new file mode 100644 index 0000000000000000000000000000000000000000..83c127b3caa4728a751ec07329242120f946e81a --- /dev/null +++ b/.github/workflows/translate_comment.yml @@ -0,0 +1,18 @@ +name: 'issue-translator' +on: + issue_comment: + types: [created] + issues: + types: [opened] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: usthe/issues-translate-action@v2.7 + with: + IS_MODIFY_TITLE: false + # not require, default false, . Decide whether to modify the issue title + # if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot. + CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑‍🤝‍🧑👫🧑🏿‍🤝‍🧑🏻👩🏾‍🤝‍👨🏿👬🏿 + # not require. Customize the translation robot prefix message. diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..bf74a753894fcbbf722812700fdf72a38c9e896c --- /dev/null +++ b/.gitignore @@ -0,0 +1,157 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/.build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# IDE +.idea/ +.vscode/ + +# macos +*.DS_Store +#data/ + +docs/.build + +# pytorch checkpoint +*.pt + +# ignore version.py generated by setup.py +colossalai/version.py + +# ignore any kernel build files +.o +.so + +# ignore python interface defition file +.pyi + +# ignore coverage test file +coverage.lcov +coverage.xml diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..2f1c34298a50eb7f03b5f2b061de5870e92ccd9d --- /dev/null +++ b/.gitmodules @@ -0,0 +1,7 @@ +[submodule "inference"] + path = inference + url = https://github.com/hpcaitech/EnergonAI.git + branch = main +[submodule "examples/tutorial/fastfold/FastFold"] + path = examples/tutorial/fastfold/FastFold + url = https://github.com/hpcaitech/FastFold diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000000000000000000000000000000000000..090aa28e39f32da8c0161d5317c710b6c8781641 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,5 @@ +[settings] +line_length = 120 +multi_line_output=3 +include_trailing_comma = true +ignore_comments = true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..725d266375ef42e69f6097f7d61924213d46b8ff --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +repos: + + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + name: sort all imports (python) + + - repo: https://github.com/pre-commit/mirrors-yapf + rev: v0.32.0 + hooks: + - id: yapf + name: yapf formatter + args: ['--style=.style.yapf', '--parallel', '--in-place'] + + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: v13.0.1 + hooks: + - id: clang-format + name: clang formatter + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: check-yaml + - id: check-merge-conflict + - id: check-case-conflict + - id: trailing-whitespace + - id: end-of-file-fixer + - id: mixed-line-ending + args: ['--fix=lf'] diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 0000000000000000000000000000000000000000..05be0dc6a3a598ebd59ff00cec93ce8b809c78b5 --- /dev/null +++ b/.style.yapf @@ -0,0 +1,5 @@ +[style] +based_on_style = google +spaces_before_comment = 4 +split_before_logical_operator = true +column_limit = 120 diff --git a/CHANGE_LOG.md b/CHANGE_LOG.md new file mode 100644 index 0000000000000000000000000000000000000000..3851376659068ecbe2e9c2e901d0a368a4667440 --- /dev/null +++ b/CHANGE_LOG.md @@ -0,0 +1,38 @@ +# Change Log + +All notable changes to this project will be documented in this file. + +🚩 **We have moved the change log to the GitHub [release page](https://github.com/hpcaitech/ColossalAI/releases)** + +## v0.0.2 | 2022-02 + +### Added + +- Unified distributed layers +- MoE support +- DevOps tools such as github action, code review automation, etc. +- New project official website + +### Changes + +- refactored the APIs for usability, flexibility and modularity +- adapted PyTorch AMP for tensor parallel +- refactored utilities for tensor parallel and pipeline parallel +- Separated benchmarks and examples as independent repositories +- Updated pipeline parallelism to support non-interleaved and interleaved versions +- refactored installation scripts for convenience + +### Fixed + +- zero level 3 runtime error +- incorrect calculation in gradient clipping + + +## v0.0.1 beta | 2021-10 + +The first beta version of Colossal-AI. Thanks to all contributors for the effort to implement the system. + +### Added + +- Initial architecture of the system +- Features such as tensor parallelism, gradient clipping, gradient accumulation diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..00abcf650158c571411f7b78a0ef4c982365b746 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,141 @@ +# Contributing + +Colossal-AI welcomes any constructive contribution from the community and the team is more than willing to work on problems you have encountered to make it a better project. + +## Environment Setup + +To contribute to Colossal-AI, we would like to first guide you to set up a proper development environment so that you can better implement your code. It is good to install this system from source with the `editable` flag (`-e`, for development mode) so that your change to the source code will be reflected in runtime without repeated installation and uninstallation. Here are the steps to set up the development environment. + +1. Uninstall any existing Colossal-AI distribution. + +```shell +pip uninstall colossalai +``` + +2. Clone the repository to local workspace + +```shell +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI +``` + +3. The *Get Started* section of [official documentation](https://colossalai.org) has provided instructions to build from source. Follow to instruction to build from source, **but replace the last `pip install` statement with the command below by adding the `-e` flag.** + +```shell +pip install -e . +``` + +## Coding Standards + +### Unit Tests +We use [PyTest](https://docs.pytest.org/en/latest/) to execute tests. You can install pytest by `pip install pytest`. As some of the tests require initialization of the distributed backend, GPUs are needed to execute these tests. + +If you only want to run CPU tests, you can run + +```bash +pytest -m cpu tests/ +``` + +If you have 8 GPUs on your machine, you can run the full test + +```bash +pytest tests/ +``` + +If you do not have 8 GPUs on your machine, do not worry. Unit testing will be automatically conducted when you put up a pull request to the main branch. + + +### Code Style + +We have some static checks when you commit your code change, please make sure you can pass all the tests and make sure the coding style meets our requirements. We use pre-commit hook to make sure the code is aligned with the writing standard. To set up the code style checking, you need to follow the steps below. + +```shell +# these commands are executed under the Colossal-AI directory +pip install pre-commit +pre-commit install +``` + +Code format checking will be automatically executed when you commit your changes. + + +## Contribution Guide + +You need to follow these steps below to make contribution to the main repository via pull request. You can learn about the details of pull request [here](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests). + +### 1. Fork the Official Repository + +Firstly, you need to visit the [Colossal-AI repository](https://github.com/hpcaitech/ColossalAI) and fork into your own account. The `fork` button is at the right top corner of the web page alongside with buttons such as `watch` and `star`. + +Now, you can clone your own forked repository into your local environment. + +```shell +git clone https://github.com//ColossalAI.git +``` + +### 2. Configure Git + +You need to set the official repository as your upstream so that you can synchronize with the latest update in the official repository. You can learn about upstream [here](https://www.atlassian.com/git/tutorials/git-forks-and-upstreams). + +Then add the original repository as upstream + +```shell +cd ColossalAI +git remote add upstream https://github.com/hpcaitech/ColossalAI.git +``` + +you can use the following command to verify that the remote is set. You should see both `origin` and `upstream` in the output. + +```shell +git remote -v +``` + +### 3. Synchronize with Official Repository + +Before you make changes to the codebase, it is always good to fetch the latest updates in the official repository. In order to do so, you can use the commands below. + +```shell +git fetch upstream +git checkout main +git merge upstream/main +git push origin main +``` + +Otherwise, you can click the `fetch upstream` button on the github webpage of the main branch of your forked repository. Then, use these commands to sync. + +``` +git checkout main +git fetch main +``` + +### 4. Choose/Create an Issue for Your Pull Request + +Generally, your code change should be only targeted at one problem. Stacking multiple commits for different problems into one pull request will only make the code review such dire suffering and make the system prone to new bugs as the reviewer may not understand the code logic correctly. Thus, you should choose an existing issue or [create your own issue](https://github.com/hpcaitech/ColossalAI/issues) as your pull request target. If you wish to create a new issue, do use appropriate title and description and add related labels. + + +### 5. Create a New Branch + +You should not make changes to the `main` branch of your forked repository as this might make upstream synchronization difficult. You can create a new branch with the appropriate name. General branch name format should start with `hotfix/` and `feature/`. `hotfix` is for bug fix and `feature` is for addition of a new feature. + + +```shell +git checkout -b +``` + +### 6. Implementation and Code Commit + +Now you can implement your code change in the source code. Remember that you installed the system in development, thus you do not need to uninstall and install to make the code take effect. The code change will be reflected in every new PyThon execution. +You can commit and push the changes to your local repository. The changes should be kept logical, modular and atomic. + +```shell +git add -A +git commit -m "" +git push -u origin +``` + +### 7. Open a Pull Request + +You can now create a pull request on the GitHub webpage of your repository. The source branch is `` of your repository and the target branch should be `main` of `hpcaitech/ColossalAI`. After creating this pull request, you should be able to see it [here](https://github.com/hpcaitech/ColossalAI/pulls). + +Do write clearly the description of your pull request and [link the pull request to your target issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue). This will automatically close the issue when the pull request is approved. + +In case of code conflict, you should rebase your branch and resolve the conflicts manually. \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..c7a5bb16880e6a2a6364a092fe94dda29399b9e3 --- /dev/null +++ b/LICENSE @@ -0,0 +1,398 @@ +Copyright 2021- HPC-AI Technology Inc. All rights reserved. + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021- HPC-AI Technology Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ## Some of colossal-ai's code is derived from others projects, which is subject to the following copyright notice: + + Copyright 2021 The Alpa team. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://github.com/alpa-projects/alpa/blob/979a45a3e6187df941ef4a4c4c6eea664527d68d/LICENSE + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ------------------------------------------------- + + Copyright 2018-2020 Philippe Tillet + Copyright 2020-2022 OpenAI + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files + (the "Software"), to deal in the Software without restriction, + including without limitation the rights to use, copy, modify, merge, + publish, distribute, sublicense, and/or sell copies of the Software, + and to permit persons to whom the Software is furnished to do so, + subject to the following conditions: + + ---------------- LICENSE FOR Microsoft Deepspeed ---------------- + + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE + + ---------------- LICENSE FOR NVIDIA Megatron-LM ---------------- + + Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of NVIDIA CORPORATION nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------- LICENSE FOR NVIDIA Apex ---------------- + + All rights reserved. + + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------- LICENSE FOR Facebook Fairscale ---------------- + + Copyright (c) Facebook, Inc. and its affiliates + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. + + ---------------- LICENSE FOR Flash Attention ---------------- + + BSD 3-Clause License + + Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------- LICENSE FOR Facebook xFormers ---------------- + + From xFormers: + + Copyright (c) Facebook, Inc. and its affiliates + + + === + + BSD 3-Clause License + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..ad26b634ac3ee5cbf3339fb5844b59a971ad8c81 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,4 @@ +include *.txt README.md +recursive-include requirements *.txt +recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi +recursive-include op_builder *.py diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..79f733122cb3a14ff27b72889caa710ac5e8a5f6 --- /dev/null +++ b/README.md @@ -0,0 +1,431 @@ +# Colossal-AI +
+ + [![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/colossal-ai_logo_vertical.png)](https://www.colossalai.org/) + + Colossal-AI: Making large AI models cheaper, faster, and more accessible + +

Paper | + Documentation | + Examples | + Forum | + Blog

+ + [![GitHub Repo stars](https://img.shields.io/github/stars/hpcaitech/ColossalAI?style=social)](https://github.com/hpcaitech/ColossalAI/stargazers) + [![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml) + [![Documentation](https://readthedocs.org/projects/colossalai/badge/?version=latest)](https://colossalai.readthedocs.io/en/latest/?badge=latest) + [![CodeFactor](https://www.codefactor.io/repository/github/hpcaitech/colossalai/badge)](https://www.codefactor.io/repository/github/hpcaitech/colossalai) + [![HuggingFace badge](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Join-yellow)](https://huggingface.co/hpcai-tech) + [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) + [![WeChat badge](https://img.shields.io/badge/微信-加入-green?logo=wechat&)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png) + + + | [English](README.md) | [中文](docs/README-zh-Hans.md) | + +
+ +## Latest News +* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) +* [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana) +* [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs) +* [2023/02] [Open Source Solution Replicates ChatGPT Training Process! Ready to go with only 1.6GB GPU Memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt) +* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://medium.com/pytorch/latest-colossal-ai-boasts-novel-automatic-parallelism-and-offers-savings-up-to-46x-for-stable-1453b48f3f02) +* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) +* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding) +* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the) + +## Table of Contents + + +## Why Colossal-AI +
+ + + + + Prof. James Demmel (UC Berkeley): Colossal-AI makes training AI models efficient, easy, and scalable. +
+ +

(back to top)

+ +## Features + +Colossal-AI provides a collection of parallel components for you. We aim to support you to write your +distributed deep learning models just like how you write your model on your laptop. We provide user-friendly tools to kickstart +distributed training and inference in a few lines. + +- Parallelism strategies + - Data Parallelism + - Pipeline Parallelism + - 1D, [2D](https://arxiv.org/abs/2104.05343), [2.5D](https://arxiv.org/abs/2105.14500), [3D](https://arxiv.org/abs/2105.14450) Tensor Parallelism + - [Sequence Parallelism](https://arxiv.org/abs/2105.13120) + - [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054) + - [Auto-Parallelism](https://arxiv.org/abs/2302.02599) + +- Heterogeneous Memory Management + - [PatrickStar](https://arxiv.org/abs/2108.05818) + +- Friendly Usage + - Parallelism based on the configuration file + +- Inference + - [Energon-AI](https://github.com/hpcaitech/EnergonAI) + +

(back to top)

+ +## Colossal-AI in the Real World + +### ColossalChat + + + +[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat): An open-source solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) [[blog]](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) [[demo]](https://chat.colossalai.org) + +

+ +

+ +- Up to 7.73 times faster for single server training and 1.42 times faster for single-GPU inference + +

+ +

+ +- Up to 10.3x growth in model capacity on one GPU +- A mini demo training process requires only 1.62GB of GPU memory (any consumer-grade GPU) + +

+ +

+ +- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU +- Keep at a sufficiently high running speed + +

(back to top)

+ + +### AIGC +Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion). +

+ +

+ +- [Training](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): Reduce Stable Diffusion memory consumption by up to 5.6x and hardware cost by up to 46x (from A100 to RTX3060). + +

+ +

+ +- [DreamBooth Fine-tuning](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/dreambooth): Personalize your model using just 3-5 images of the desired subject. + +

+ +

+ +- [Inference](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): Reduce inference GPU memory consumption by 2.5x. + + +

(back to top)

+ +### Biomedicine +Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/) + +

+ +

+ +- [FastFold](https://github.com/hpcaitech/FastFold): Accelerating training and inference on GPU Clusters, faster data processing, inference sequence containing more than 10000 residues. + +

+ +

+ +- [FastFold with Intel](https://github.com/hpcaitech/FastFold): 3x inference acceleration and 39% cost reduce. + +

+ +

+ +- [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): accelerating structure prediction of protein monomers and multimer by 11x. + + +

(back to top)

+ +## Parallel Training Demo + +### GPT-3 +

+ +

+ +- Save 50% GPU resources and 10.7% acceleration + +### GPT-2 + + +- 11x lower GPU memory consumption, and superlinear scaling efficiency with Tensor Parallelism + + + +- 24x larger model size on the same hardware +- over 3x acceleration +### BERT + + +- 2x faster training, or 50% longer sequence length + +### PaLM +- [PaLM-colossalai](https://github.com/hpcaitech/PaLM-colossalai): Scalable implementation of Google's Pathways Language Model ([PaLM](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html)). + +### OPT + + +- [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model released by Meta, which stimulates AI programmers to perform various downstream tasks and application deployments because of public pre-trained model weights. +- 45% speedup fine-tuning OPT at low cost in lines. [[Example]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/opt) [[Online Serving]](https://colossalai.org/docs/advanced_tutorials/opt_service) + +Please visit our [documentation](https://www.colossalai.org/) and [examples](https://github.com/hpcaitech/ColossalAI/tree/main/examples) for more details. + +### ViT +

+ +

+ +- 14x larger batch size, and 5x faster training for Tensor Parallelism = 64 + +### Recommendation System Models +- [Cached Embedding](https://github.com/hpcaitech/CachedEmbedding), utilize software cache to train larger embedding tables with a smaller GPU memory budget. + +

(back to top)

+ +## Single GPU Training Demo + +### GPT-2 +

+ +

+ +- 20x larger model size on the same hardware + +

+ +

+ +- 120x larger model size on the same hardware (RTX 3080) + +### PaLM +

+ +

+ +- 34x larger model size on the same hardware + +

(back to top)

+ + +## Inference (Energon-AI) Demo + +

+ +

+ +- [Energon-AI](https://github.com/hpcaitech/EnergonAI): 50% inference acceleration on the same hardware + +

+ +

+ +- [OPT Serving](https://colossalai.org/docs/advanced_tutorials/opt_service): Try 175-billion-parameter OPT online services + +

+ +

+ +- [BLOOM](https://github.com/hpcaitech/EnergonAI/tree/main/examples/bloom): Reduce hardware deployment costs of 176-billion-parameter BLOOM by more than 10 times. + +

(back to top)

+ +## Installation + +Requirements: +- PyTorch >= 1.11 (PyTorch 2.x in progress) +- Python >= 3.7 +- CUDA >= 11.0 +- [NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus) >= 7.0 (V100/RTX20 and higher) +- Linux OS + +If you encounter any problem with installation, you may want to raise an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) in this repository. + +### Install from PyPI + +You can easily install Colossal-AI with the following command. **By default, we do not build PyTorch extensions during installation.** + +```bash +pip install colossalai +``` + +**Note: only Linux is supported for now.** + +However, if you want to build the PyTorch extensions during installation, you can set `CUDA_EXT=1`. + +```bash +CUDA_EXT=1 pip install colossalai +``` + +**Otherwise, CUDA kernels will be built during runtime when you actually need them.** + +We also keep releasing the nightly version to PyPI every week. This allows you to access the unreleased features and bug fixes in the main branch. +Installation can be made via + +```bash +pip install colossalai-nightly +``` + +### Download From Source + +> The version of Colossal-AI will be in line with the main branch of the repository. Feel free to raise an issue if you encounter any problems. :) + +```shell +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI + +# install colossalai +pip install . +``` + +By default, we do not compile CUDA/C++ kernels. ColossalAI will build them during runtime. +If you want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer): + +```shell +CUDA_EXT=1 pip install . +``` + +

(back to top)

+ +## Use Docker + +### Pull from DockerHub + +You can directly pull the docker image from our [DockerHub page](https://hub.docker.com/r/hpcaitech/colossalai). The image is automatically uploaded upon release. + + +### Build On Your Own + +Run the following command to build a docker image from Dockerfile provided. + +> Building Colossal-AI from scratch requires GPU support, you need to use Nvidia Docker Runtime as the default when doing `docker build`. More details can be found [here](https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime). +> We recommend you install Colossal-AI from our [project page](https://www.colossalai.org) directly. + + +```bash +cd ColossalAI +docker build -t colossalai ./docker +``` + +Run the following command to start the docker container in interactive mode. + +```bash +docker run -ti --gpus all --rm --ipc=host colossalai bash +``` + +

(back to top)

+ +## Community + +Join the Colossal-AI community on [Forum](https://github.com/hpcaitech/ColossalAI/discussions), +[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your suggestions, feedback, and questions with our engineering team. + +## Contributing +Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models! + +You may contact us or participate in the following ways: +1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks! +2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md) +3. Send your official proposal to email contact@hpcaitech.com + +Thanks so much to all of our amazing contributors! + + + + + + +

(back to top)

+ + +## CI/CD + +We leverage the power of [GitHub Actions](https://github.com/features/actions) to automate our development, release and deployment workflows. Please check out this [documentation](.github/workflows/README.md) on how the automated workflows are operated. + + +## Cite Us + +This project is inspired by some related projects (some by our team and some by other organizations). We would like to credit these amazing projects as listed in the [Reference List](./docs/REFERENCE.md). + +To cite this project, you can use the following BibTeX citation. + +``` +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +``` + +Colossal-AI has been accepted as official tutorial by top conferences [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc. + +

(back to top)

diff --git a/applications/Chat/.gitignore b/applications/Chat/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..2b9b4f345d0fae7bc3872d3c723d2698d201b8b8 --- /dev/null +++ b/applications/Chat/.gitignore @@ -0,0 +1,148 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/.build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# IDE +.idea/ +.vscode/ + +# macos +*.DS_Store +#data/ + +docs/.build + +# pytorch checkpoint +*.pt + +# wandb log +example/wandb/ + +examples/awesome-chatgpt-prompts/ \ No newline at end of file diff --git a/applications/Chat/LICENSE b/applications/Chat/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0528c89ea9ecd51713b60db68ae69702d8d164f7 --- /dev/null +++ b/applications/Chat/LICENSE @@ -0,0 +1,202 @@ +Copyright 2021- HPC-AI Technology Inc. All rights reserved. + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021- HPC-AI Technology Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/applications/Chat/README.md b/applications/Chat/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9ba831973b6c912a2d8435c382e4c741dd24b6e4 --- /dev/null +++ b/applications/Chat/README.md @@ -0,0 +1,450 @@ +

+ +
+ ColossalChat +

+ + +## Table of Contents + +- [Table of Contents](#table-of-contents) +- [What is ColossalChat and Coati ?](#what-is-colossalchat-and-coati-) +- [Online demo](#online-demo) +- [Install](#install) + - [Install the environment](#install-the-environment) + - [Install the Transformers](#install-the-transformers) +- [How to use?](#how-to-use) + - [Supervised datasets collection](#supervised-datasets-collection) + - [RLHF Training Stage1 - Supervised instructs tuning](#RLHF-training-stage1---supervised-instructs-tuning) + - [RLHF Training Stage2 - Training reward model](#RLHF-training-stage2---training-reward-model) + - [RLHF Training Stage3 - Training model with reinforcement learning by human feedback](#RLHF-training-stage3---training-model-with-reinforcement-learning-by-human-feedback) + - [Inference Quantization and Serving - After Training](#inference-quantization-and-serving---after-training) +- [Coati7B examples](#coati7b-examples) + - [Generation](#generation) + - [Open QA](#open-qa) + - [Limitation for LLaMA-finetuned models](#limitation) + - [Limitation of dataset](#limitation) +- [FAQ](#faq) + - [How to save/load checkpoint](#faq) + - [How to train with limited resources](#faq) +- [The Plan](#the-plan) + - [Real-time progress](#real-time-progress) +- [Invitation to open-source contribution](#invitation-to-open-source-contribution) +- [Quick Preview](#quick-preview) +- [Authors](#authors) +- [Citations](#citations) +- [Licenses](#licenses) +--- +## What is ColossalChat and Coati ? + +[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) is the project to implement LLM with RLHF, powered by the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) project. + +Coati stands for `ColossalAI Talking Intelligence`. It is the name for the module implemented in this project and is also the name of the large language model developed by the ColossalChat project. + +The Coati package provides a unified large language model framework that has implemented the following functions +- Supports comprehensive large-model training acceleration capabilities for ColossalAI, without requiring knowledge of complex distributed training algorithms +- Supervised datasets collection +- Supervised instructions fine-tuning +- Training reward model +- Reinforcement learning with human feedback +- Quantization inference +- Fast model deploying +- Perfectly integrated with the Hugging Face ecosystem, a high degree of model customization + +
+

+ +

+ + Image source: https://openai.com/blog/chatgpt +
+ +**As Colossal-AI is undergoing some major updates, this project will be actively maintained to stay in line with the Colossal-AI project.** + + +More details can be found in the latest news. +* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) +* [2023/02] [Open Source Solution Replicates ChatGPT Training Process! Ready to go with only 1.6GB GPU Memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt) + +## Online demo +You can experience the performance of Coati7B on this page. + +[chat.colossalai.org](https://chat.colossalai.org/) + +Due to resource constraints, we will only provide this service from 29th Mar 2023 to 5 April 2023. However, we have provided the inference code in the [inference](./inference/) folder. The WebUI will be open-sourced soon as well. + +> Warning: Due to model and dataset size limitations, Coati is just a baby model, Coati7B may output incorrect information and lack the ability for multi-turn dialogue. There is still significant room for improvement. +## Install + +### Install the environment + +```shell +conda create -n coati +conda activate coati +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI/applications/Chat +pip install . +``` + +### Install the Transformers +Given Hugging Face hasn't officially supported the LLaMA models, We fork a branch of Transformers that can be compatible with our code + +```shell +git clone https://github.com/hpcaitech/transformers +cd transformers +pip install . +``` + +## How to use? + +### Supervised datasets collection + +we collected 104K bilingual datasets of Chinese and English, and you can find the datasets in this repo +[InstructionWild](https://github.com/XueFuzhao/InstructionWild) + +Here is how we collected the data +

+ +

+ +### RLHF Training Stage1 - Supervised instructs tuning + +Stage1 is supervised instructs fine-tuning, which uses the datasets mentioned earlier to fine-tune the model. + +You can run the `examples/train_sft.sh` to start a supervised instructs fine-tuning. + +### RLHF Training Stage2 - Training reward model + +Stage2 trains a reward model, which obtains corresponding scores by manually ranking different outputs for the same prompt and supervises the training of the reward model + +You can run the `examples/train_rm.sh` to start a reward model training. + +### RLHF Training Stage3 - Training model with reinforcement learning by human feedback + +Stage3 uses reinforcement learning algorithm, which is the most complex part of the training process: + +

+ +

+ +You can run the `examples/train_prompts.sh` to start training PPO with human feedback. + +For more details, see [`examples/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples). + +### Inference Quantization and Serving - After Training + +We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models. + +We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inference. You can +Online inference server scripts can help you deploy your own services. + +For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference). + +## Coati7B examples + +### Generation + +
E-mail + +![phd](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/Phd.png) +
+ +
coding + +![sort](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/quick_sort.png) + +
+ +
regex + +![regex](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/regex.png) + +
+ +
Tex + +![tex](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/tex.png) + +
+ +
writing + +![writing](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/writing.png) + +
+ +
Table + +![Table](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/table.png) + +
+ +### Open QA +
Game + +![Game](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/game.png) + +
+ +
Travel + +![Travel](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/travel.png) + +
+ +
Physical + +![Physical](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/physical.png) + +
+ +
Chemical + +![Chemical](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/chemical.png) + +
+ +
Economy + +![Economy](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/economy.png) + +
+ +You can find more examples in this [repo](https://github.com/XueFuzhao/InstructionWild/blob/main/comparison.md). + +### Limitation +
Limitation for LLaMA-finetuned models +- Both Alpaca and ColossalChat are based on LLaMA. It is hard to compensate for the missing knowledge in the pre-training stage. +- Lack of counting ability: Cannot count the number of items in a list. +- Lack of Logics (reasoning and calculation) +- Tend to repeat the last sentence (fail to produce the end token). +- Poor multilingual results: LLaMA is mainly trained on English datasets (Generation performs better than QA). +
+ +
Limitation of dataset +- Lack of summarization ability: No such instructions in finetune datasets. +- Lack of multi-turn chat: No such instructions in finetune datasets +- Lack of self-recognition: No such instructions in finetune datasets +- Lack of Safety: + - When the input contains fake facts, the model makes up false facts and explanations. + - Cannot abide by OpenAI's policy: When generating prompts from OpenAI API, it always abides by its policy. So no violation case is in the datasets. +
+ +## FAQ + +
How to save/load checkpoint + +We have integrated the Transformers save and load pipeline, allowing users to freely call Hugging Face's language models and save them in the HF format. + +``` +from coati.models.llama import LlamaLM +from coati.trainer import SFTTrainer + +model = LlamaLM(pretrained=args.pretrain) +tokenizer = AutoTokenizer.from_pretrained(args.pretrain) + +(model, optim) = strategy.prepare((model, optim)) +trainer = SFTTrainer(model=model, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + batch_size=args.batch_size, + max_epochs=args.max_epochs, + accumulation_steps = args.accumulation_steps +) + +trainer.fit() +# this saves in pytorch format +strategy.save_model(model, args.save_path, only_rank0=True) + +# this saves in HF format. ColossalAI strategy with stage-3 doesn't support this method +strategy.save_pretrained(model, args.save_path, only_rank0=True, tokenizer=tokenizer) +``` + +
+ +
How to train with limited resources + +Here are some examples that can allow you to train a 7B model on a single or multiple consumer-grade GPUs. + +If you only have a single 24G GPU, you can use the following script. `batch_size`, `lora_rank` and `grad_checkpoint` are the most important parameters to successfully train the model. +``` +torchrun --standalone --nproc_per_node=1 train_sft.py \ + --pretrain "/path/to/LLaMa-7B/" \ + --model 'llama' \ + --strategy naive \ + --log_interval 10 \ + --save_path /path/to/Coati-7B \ + --dataset /path/to/data.json \ + --batch_size 1 \ + --accumulation_steps 8 \ + --lr 2e-5 \ + --max_datasets_size 512 \ + --max_epochs 1 \ + --lora_rank 16 \ + --grad_checkpoint +``` + +`colossalai_gemini` strategy can enable a single 24G GPU to train the whole model without using LoRA if you have sufficient CPU memory. You can use the following script. +``` +torchrun --standalone --nproc_per_node=1 train_sft.py \ + --pretrain "/path/to/LLaMa-7B/" \ + --model 'llama' \ + --strategy colossalai_gemini \ + --log_interval 10 \ + --save_path /path/to/Coati-7B \ + --dataset /path/to/data.json \ + --batch_size 1 \ + --accumulation_steps 8 \ + --lr 2e-5 \ + --max_datasets_size 512 \ + --max_epochs 1 \ + --grad_checkpoint +``` + +If you have 4x32 GB GPUs, you can even train the whole 7B model using our `colossalai_zero2_cpu` strategy! The script is given as follows. +``` +torchrun --standalone --nproc_per_node=4 train_sft.py \ + --pretrain "/path/to/LLaMa-7B/" \ + --model 'llama' \ + --strategy colossalai_zero2_cpu \ + --log_interval 10 \ + --save_path /path/to/Coati-7B \ + --dataset /path/to/data.json \ + --batch_size 1 \ + --accumulation_steps 8 \ + --lr 2e-5 \ + --max_datasets_size 512 \ + --max_epochs 1 \ + --grad_checkpoint +``` +
+ + +## The Plan + +- [x] implement PPO fine-tuning +- [x] implement training reward model +- [x] support LoRA +- [x] support inference +- [x] support llama from [facebook](https://github.com/facebookresearch/llama) +- [x] implement PPO-ptx fine-tuning +- [ ] integrate with Ray +- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL), +- [ ] support chain-of-thought by [langchain](https://github.com/hwchase17/langchain) + +### Real-time progress +You will find our progress in github project broad + +[Coati](https://github.com/orgs/hpcaitech/projects/17/views/1) + +## Invitation to open-source contribution +Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models from the starting point of replicating ChatGPT! + +You may contact us or participate in the following ways: +1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks! +2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). +3. Join the Colossal-AI community on +[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. +4. Send your official proposal to email contact@hpcaitech.com + +Thanks so much to all of our amazing contributors! + +## Quick Preview + + +- An open-source low cost solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[demo]](https://chat.colossalai.org) + +

+ +

+ +- Up to 7.73 times faster for single server training and 1.42 times faster for single-GPU inference + +

+ +

+ +- Up to 10.3x growth in model capacity on one GPU +- A mini demo training process requires only 1.62GB of GPU memory (any consumer-grade GPU) + +

+ +

+ +- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU +- Keep in a sufficiently high running speed + +| Model Pair | Alpaca-7B ⚔ Coati-7B | Coati-7B ⚔ Alpaca-7B | +| :-----------: | :------------------: | :------------------: | +| Better Cases | 38 ⚔ **41** | **45** ⚔ 33 | +| Win Rate | 48% ⚔ **52%** | **58%** ⚔ 42% | +| Average Score | 7.06 ⚔ **7.13** | **7.31** ⚔ 6.82 | +- Our Coati-7B model performs better than Alpaca-7B when using GPT-4 to evaluate model performance. The Coati-7B model we evaluate is an old version we trained a few weeks ago and the new version is around the corner. + +## Authors + +Coati is developed by ColossalAI Team: +- [Fazzie](https://fazzie-key.cool/about/index.html) +- [FrankLeeeee](https://github.com/FrankLeeeee) +- [BlueRum](https://github.com/ht-zhou) +- [ver217](https://github.com/ver217) +- [ofey404](https://github.com/ofey404) + +The Phd student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project. +- [Zangwei Zheng](https://github.com/zhengzangw) +- [Xue Fuzhao](https://github.com/XueFuzhao) + +## Citations + +```bibtex +@article{Hu2021LoRALA, + title = {LoRA: Low-Rank Adaptation of Large Language Models}, + author = {Edward J. Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Weizhu Chen}, + journal = {ArXiv}, + year = {2021}, + volume = {abs/2106.09685} +} + +@article{ouyang2022training, + title={Training language models to follow instructions with human feedback}, + author={Ouyang, Long and Wu, Jeff and Jiang, Xu and Almeida, Diogo and Wainwright, Carroll L and Mishkin, Pamela and Zhang, Chong and Agarwal, Sandhini and Slama, Katarina and Ray, Alex and others}, + journal={arXiv preprint arXiv:2203.02155}, + year={2022} +} + +@article{touvron2023llama, + title={LLaMA: Open and Efficient Foundation Language Models}, + author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and Rodriguez, Aurelien and Joulin, Armand and Grave, Edouard and Lample, Guillaume}, + journal={arXiv preprint arXiv:2302.13971}, + year={2023} +} + +@misc{alpaca, + author = {Rohan Taori and Ishaan Gulrajani and Tianyi Zhang and Yann Dubois and Xuechen Li and Carlos Guestrin and Percy Liang and Tatsunori B. Hashimoto }, + title = {Stanford Alpaca: An Instruction-following LLaMA model}, + year = {2023}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/tatsu-lab/stanford_alpaca}}, +} + +@misc{instructionwild, + author = {Fuzhao Xue and Zangwei Zheng and Yang You }, + title = {Instruction in the Wild: A User-based Instruction Dataset}, + year = {2023}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/XueFuzhao/InstructionWild}}, +} +``` + +## Licenses + +Coati is licensed under the [Apache 2.0 License](LICENSE). diff --git a/applications/Chat/benchmarks/README.md b/applications/Chat/benchmarks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bc8ad8ba98165ebfefd842e6deea27120bdb547e --- /dev/null +++ b/applications/Chat/benchmarks/README.md @@ -0,0 +1,35 @@ +# Benchmarks + +## Benchmark OPT with LoRA on dummy prompt data + +We provide various OPT models (string in parentheses is the corresponding model name used in this script): + +- OPT-125M (125m) +- OPT-350M (350m) +- OPT-700M (700m) +- OPT-1.3B (1.3b) +- OPT-2.7B (2.7b) +- OPT-3.5B (3.5b) +- OPT-5.5B (5.5b) +- OPT-6.7B (6.7b) +- OPT-10B (10b) +- OPT-13B (13b) + +We also provide various training strategies: + +- ddp: torch DDP +- colossalai_gemini: ColossalAI GeminiDDP with `placement_policy="cuda"`, like zero3 +- colossalai_gemini_cpu: ColossalAI GeminiDDP with `placement_policy="cpu"`, like zero3-offload +- colossalai_zero2: ColossalAI zero2 +- colossalai_zero2_cpu: ColossalAI zero2-offload +- colossalai_zero1: ColossalAI zero1 +- colossalai_zero1_cpu: ColossalAI zero1-offload + +We only support `torchrun` to launch now. E.g. + +```shell +# run OPT-125M with no lora (lora_rank=0) on single-node single-GPU with min batch size +torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py --model 125m --critic_model 125m --strategy ddp --experience_batch_size 1 --train_batch_size 1 --lora_rank 0 +# run Actor (OPT-1.3B) and Critic (OPT-350M) with lora_rank=4 on single-node 4-GPU +torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py --model 1.3b --critic_model 350m --strategy colossalai_zero2 --lora_rank 4 +``` diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..7a47624f74d87f188994cf7ee59e5f2d1f2b0b8b --- /dev/null +++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py @@ -0,0 +1,194 @@ +import argparse +from copy import deepcopy + +import torch +import torch.distributed as dist +import torch.nn as nn +from coati.models.base import RewardModel +from coati.models.opt import OPTActor, OPTCritic +from coati.trainer import PPOTrainer +from coati.trainer.callbacks import PerformanceEvaluator +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import AutoTokenizer +from transformers.models.opt.configuration_opt import OPTConfig + +from colossalai.nn.optimizer import HybridAdam + + +def get_model_numel(model: nn.Module, strategy: Strategy) -> int: + numel = sum(p.numel() for p in model.parameters()) + if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init: + numel *= dist.get_world_size() + return numel + + +def preprocess_batch(samples) -> dict: + input_ids = torch.stack(samples) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + return {'input_ids': input_ids, 'attention_mask': attention_mask} + + +def print_rank_0(*args, **kwargs) -> None: + if dist.get_rank() == 0: + print(*args, **kwargs) + + +def print_model_numel(model_dict: dict) -> None: + B = 1024**3 + M = 1024**2 + K = 1024 + outputs = '' + for name, numel in model_dict.items(): + outputs += f'{name}: ' + if numel >= B: + outputs += f'{numel / B:.2f} B\n' + elif numel >= M: + outputs += f'{numel / M:.2f} M\n' + elif numel >= K: + outputs += f'{numel / K:.2f} K\n' + else: + outputs += f'{numel}\n' + print_rank_0(outputs) + + +def get_gpt_config(model_name: str) -> OPTConfig: + model_map = { + '125m': OPTConfig.from_pretrained('facebook/opt-125m'), + '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16), + '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20), + '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'), + '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'), + '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32), + '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32), + '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'), + '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32), + '13b': OPTConfig.from_pretrained('facebook/opt-13b'), + } + try: + return model_map[model_name] + except KeyError: + raise ValueError(f'Unknown model "{model_name}"') + + +def main(args): + if args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) + elif args.strategy == 'colossalai_gemini_cpu': + strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5) + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2_cpu': + strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') + elif args.strategy == 'colossalai_zero1': + strategy = ColossalAIStrategy(stage=1, placement_policy='cuda') + elif args.strategy == 'colossalai_zero1_cpu': + strategy = ColossalAIStrategy(stage=1, placement_policy='cpu') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac) + + model_config = get_gpt_config(args.model) + critic_config = get_gpt_config(args.critic_model) + with strategy.model_init_context(): + actor = OPTActor(config=model_config, lora_rank=args.lora_rank).cuda() + critic = OPTCritic(config=critic_config, lora_rank=args.lora_rank).cuda() + + initial_model = deepcopy(actor).cuda().half() + reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda().half() + + if args.use_kernels: + from coati.kernels import convert_to_xformer_model + actor, critic, initial_model, reward_model = map(convert_to_xformer_model, + (actor, critic, initial_model, reward_model)) + + actor_numel = get_model_numel(actor, strategy) + critic_numel = get_model_numel(critic, strategy) + initial_model_numel = get_model_numel(initial_model, strategy) + reward_model_numel = get_model_numel(reward_model, strategy) + print_model_numel({ + 'Actor': actor_numel, + 'Critic': critic_numel, + 'Initial model': initial_model_numel, + 'Reward model': reward_model_numel + }) + performance_evaluator = PerformanceEvaluator(actor_numel, + critic_numel, + initial_model_numel, + reward_model_numel, + enable_grad_checkpoint=False, + ignore_episodes=1) + + if args.strategy.startswith('colossalai'): + actor_optim = HybridAdam(actor.parameters(), lr=5e-6) + critic_optim = HybridAdam(critic.parameters(), lr=5e-6) + else: + actor_optim = Adam(actor.parameters(), lr=5e-6) + critic_optim = Adam(critic.parameters(), lr=5e-6) + + tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') + tokenizer.pad_token = tokenizer.eos_token + + (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) + + trainer = PPOTrainer(strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + ptx_coef=0, + max_epochs=args.max_epochs, + train_batch_size=args.train_batch_size, + offload_inference_models=args.offload_inference_models, + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + use_cache=True, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + callbacks=[performance_evaluator]) + + random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device()) + dataloader = DataLoader(random_prompts, + batch_size=args.experience_batch_size, + shuffle=True, + collate_fn=preprocess_batch) + + trainer.fit(dataloader, + None, + num_episodes=args.num_episodes, + max_timesteps=args.max_timesteps, + update_timesteps=args.update_timesteps) + + print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='125m') + parser.add_argument('--critic_model', default='125m') + parser.add_argument('--strategy', + choices=[ + 'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2', + 'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu' + ], + default='ddp') + parser.add_argument('--num_episodes', type=int, default=3) + parser.add_argument('--max_timesteps', type=int, default=8) + parser.add_argument('--update_timesteps', type=int, default=8) + parser.add_argument('--max_epochs', type=int, default=1) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0) + parser.add_argument('--cuda_mem_frac', type=float, default=1.0) + parser.add_argument('--offload_inference_models', action='store_true', default=False) + parser.add_argument('--use_kernels', action='store_true', default=False) + args = parser.parse_args() + main(args) diff --git a/applications/Chat/coati/__init__.py b/applications/Chat/coati/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/applications/Chat/coati/dataset/__init__.py b/applications/Chat/coati/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f650668e90b0feddbf6b20f1e0fc084a61ac4fc5 --- /dev/null +++ b/applications/Chat/coati/dataset/__init__.py @@ -0,0 +1,9 @@ +from .prompt_dataset import PromptDataset +from .reward_dataset import HhRlhfDataset, RmStaticDataset +from .sft_dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset +from .utils import is_rank_0 + +__all__ = [ + 'RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'SupervisedDataset', + 'DataCollatorForSupervisedDataset', 'PromptDataset' +] diff --git a/applications/Chat/coati/dataset/prompt_dataset.py b/applications/Chat/coati/dataset/prompt_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ab2346c4b79c31036c0f5904a3344d6729149e --- /dev/null +++ b/applications/Chat/coati/dataset/prompt_dataset.py @@ -0,0 +1,51 @@ +import copy +import random +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Callable, Dict, Sequence + +import torch +import torch.distributed as dist +import transformers +from torch.utils.data import Dataset +from tqdm import tqdm + +from colossalai.logging import get_dist_logger + +from .utils import is_rank_0, jload + +logger = get_dist_logger() + + +class PromptDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, + data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + max_datasets_size: int = None, + max_length: int = 96): + super(PromptDataset, self).__init__() + self.keyed_prompt = defaultdict(list) + logger.info("Loading data...") + list_data_dict = jload(data_path) + logger.info(f"Loaded {len(list_data_dict)} examples.") + + if max_datasets_size is not None: + logger.info(f"Limiting dataset to {max_datasets_size} examples.") + list_data_dict = list_data_dict[:max_datasets_size] + + for data_dict in list_data_dict: + token = tokenizer(data_dict["instruction"], + return_tensors='pt', + max_length=max_length, + padding='max_length', + truncation=True) + for k, tensor in token.items(): + self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind()) + + def __len__(self): + return len(self.keyed_prompt) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return {k: v[i] for k, v in self.keyed_prompt.items()} diff --git a/applications/Chat/coati/dataset/reward_dataset.py b/applications/Chat/coati/dataset/reward_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..faa1c94d27286a7c69a62c7da699a3377fe14cd9 --- /dev/null +++ b/applications/Chat/coati/dataset/reward_dataset.py @@ -0,0 +1,112 @@ +from typing import Callable + +from torch.utils.data import Dataset +from tqdm import tqdm + +from .utils import is_rank_0 + + +# Dahaos/rm-static +class RmStaticDataset(Dataset): + """ + Dataset for reward model + + Args: + dataset: dataset for reward model + tokenizer: tokenizer for reward model + max_length: max length of input + special_token: special token at the end of sentence + """ + + def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: + super().__init__() + self.chosen = [] + self.reject = [] + if special_token is None: + self.end_token = tokenizer.eos_token + else: + self.end_token = special_token + for data in tqdm(dataset, disable=not is_rank_0()): + prompt = data['prompt'] + + chosen = prompt + data['chosen'] + self.end_token + chosen_token = tokenizer(chosen, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.chosen.append({ + "input_ids": chosen_token['input_ids'], + "attention_mask": chosen_token['attention_mask'] + }) + + reject = prompt + data['rejected'] + self.end_token + reject_token = tokenizer(reject, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.reject.append({ + "input_ids": reject_token['input_ids'], + "attention_mask": reject_token['attention_mask'] + }) + + def __len__(self): + length = len(self.chosen) + return length + + def __getitem__(self, idx): + return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ + "input_ids"], self.reject[idx]["attention_mask"] + + +# Anthropic/hh-rlhf +class HhRlhfDataset(Dataset): + """ + Dataset for reward model + + Args: + dataset: dataset for reward model + tokenizer: tokenizer for reward model + max_length: max length of input + special_token: special token at the end of sentence + """ + + def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: + super().__init__() + self.chosen = [] + self.reject = [] + if special_token is None: + self.end_token = tokenizer.eos_token + else: + self.end_token = special_token + for data in tqdm(dataset, disable=not is_rank_0()): + chosen = data['chosen'] + self.end_token + chosen_token = tokenizer(chosen, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.chosen.append({ + "input_ids": chosen_token['input_ids'], + "attention_mask": chosen_token['attention_mask'] + }) + + reject = data['rejected'] + self.end_token + reject_token = tokenizer(reject, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.reject.append({ + "input_ids": reject_token['input_ids'], + "attention_mask": reject_token['attention_mask'] + }) + + def __len__(self): + length = len(self.chosen) + return length + + def __getitem__(self, idx): + return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ + "input_ids"], self.reject[idx]["attention_mask"] diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3e2453468bbc4d5be00e9c2964803299bf176004 --- /dev/null +++ b/applications/Chat/coati/dataset/sft_dataset.py @@ -0,0 +1,166 @@ +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import random +from dataclasses import dataclass, field +from typing import Callable, Dict, Sequence + +import torch +import torch.distributed as dist +import transformers +from torch.utils.data import Dataset +from tqdm import tqdm + +from colossalai.logging import get_dist_logger + +from .utils import is_rank_0, jload + +logger = get_dist_logger() + +IGNORE_INDEX = -100 +PROMPT_DICT = { + "prompt_input": + ("Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"), + "prompt_no_input": ("Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:"), +} + + +class SFTDataset(Dataset): + """ + Dataset for sft model + + Args: + dataset: dataset for supervised model + tokenizer: tokenizer for supervised model + max_length: max length of input + """ + + def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None: + super().__init__() + self.input_ids = [] + + for data in tqdm(dataset, disable=not is_rank_0()): + prompt = data['prompt'] + data['completion'] + tokenizer.eos_token + prompt_token = tokenizer(prompt, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + + self.input_ids.append(prompt_token['input_ids'][0]) + self.labels = copy.deepcopy(self.input_ids) + + def __len__(self): + length = len(self.input_ids) + return length + + def __getitem__(self, idx): + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx]) + + +def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=max_length, + truncation=True, + ) for text in strings + ] + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def preprocess( + sources: Sequence[str], + targets: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + max_length: int, +) -> Dict: + """Preprocess the data by tokenizing.""" + examples = [s + t for s, t in zip(sources, targets)] + examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)] + input_ids = examples_tokenized["input_ids"] + labels = copy.deepcopy(input_ids) + for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): + label[:source_len] = IGNORE_INDEX + return dict(input_ids=input_ids, labels=labels) + + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512): + super(SupervisedDataset, self).__init__() + logger.info("Loading data...") + list_data_dict = jload(data_path) + logger.info(f"Loaded {len(list_data_dict)} examples.") + + if max_datasets_size is not None: + logger.info(f"Limiting dataset to {max_datasets_size} examples.") + list_data_dict = list_data_dict[:max_datasets_size] + + logger.info("Formatting inputs...") + prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] + sources = [ + prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example) + for example in list_data_dict + ] + targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] + + logger.info("Tokenizing inputs... This may take some time...") + data_dict = preprocess(sources, targets, tokenizer, max_length) + + self.input_ids = data_dict["input_ids"] + self.labels = data_dict["labels"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict(input_ids=self.input_ids[i], labels=self.labels[i]) + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) + return dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) diff --git a/applications/Chat/coati/dataset/utils.py b/applications/Chat/coati/dataset/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f37fce67a7c60083b33945378ca5c392243daa04 --- /dev/null +++ b/applications/Chat/coati/dataset/utils.py @@ -0,0 +1,22 @@ +import io +import json + +import torch.distributed as dist + + +def is_rank_0() -> bool: + return not dist.is_initialized() or dist.get_rank() == 0 + + +def _make_r_io_base(f, mode: str): + if not isinstance(f, io.IOBase): + f = open(f, mode=mode) + return f + + +def jload(f, mode="r"): + """Load a .json file into a dictionary.""" + f = _make_r_io_base(f, mode) + jdict = json.load(f) + f.close() + return jdict diff --git a/applications/Chat/coati/experience_maker/__init__.py b/applications/Chat/coati/experience_maker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..39ca7576b22761807b5804cc15d82e15fdde42cd --- /dev/null +++ b/applications/Chat/coati/experience_maker/__init__.py @@ -0,0 +1,4 @@ +from .base import Experience, ExperienceMaker +from .naive import NaiveExperienceMaker + +__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker'] diff --git a/applications/Chat/coati/experience_maker/base.py b/applications/Chat/coati/experience_maker/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ff75852576c848625d8786298a46f59e6a395598 --- /dev/null +++ b/applications/Chat/coati/experience_maker/base.py @@ -0,0 +1,77 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from coati.models.base import Actor + + +@dataclass +class Experience: + """Experience is a batch of data. + These data should have the the sequence length and number of actions. + Left padding for sequences is applied. + + Shapes of each tensor: + sequences: (B, S) + action_log_probs: (B, A) + values: (B) + reward: (B) + advantages: (B) + attention_mask: (B, S) + action_mask: (B, A) + + "A" is the number of actions. + """ + sequences: torch.Tensor + action_log_probs: torch.Tensor + values: torch.Tensor + reward: torch.Tensor + advantages: torch.Tensor + attention_mask: Optional[torch.LongTensor] + action_mask: Optional[torch.BoolTensor] + + @torch.no_grad() + def to_device(self, device: torch.device) -> None: + self.sequences = self.sequences.to(device) + self.action_log_probs = self.action_log_probs.to(device) + self.values = self.values.to(device) + self.reward = self.reward.to(device) + self.advantages = self.advantages.to(device) + if self.attention_mask is not None: + self.attention_mask = self.attention_mask.to(device) + if self.action_mask is not None: + self.action_mask = self.action_mask.to(device) + + def pin_memory(self): + self.sequences = self.sequences.pin_memory() + self.action_log_probs = self.action_log_probs.pin_memory() + self.values = self.values.pin_memory() + self.reward = self.reward.pin_memory() + self.advantages = self.advantages.pin_memory() + if self.attention_mask is not None: + self.attention_mask = self.attention_mask.pin_memory() + if self.action_mask is not None: + self.action_mask = self.action_mask.pin_memory() + return self + + +class ExperienceMaker(ABC): + + def __init__(self, + actor: Actor, + critic: nn.Module, + reward_model: nn.Module, + initial_model: Actor, + kl_coef: float = 0.1) -> None: + super().__init__() + self.actor = actor + self.critic = critic + self.reward_model = reward_model + self.initial_model = initial_model + self.kl_coef = kl_coef + + @abstractmethod + def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience: + pass diff --git a/applications/Chat/coati/experience_maker/naive.py b/applications/Chat/coati/experience_maker/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..94546eeb28e787df620d15ec631b1776dc6bf282 --- /dev/null +++ b/applications/Chat/coati/experience_maker/naive.py @@ -0,0 +1,35 @@ +import torch +from coati.models.utils import compute_reward, normalize + +from .base import Experience, ExperienceMaker + + +class NaiveExperienceMaker(ExperienceMaker): + """ + Naive experience maker. + """ + + @torch.no_grad() + def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience: + self.actor.eval() + self.critic.eval() + self.initial_model.eval() + self.reward_model.eval() + + sequences, attention_mask, action_mask = self.actor.generate(input_ids, + return_action_mask=True, + **generate_kwargs) + num_actions = action_mask.size(1) + + action_log_probs = self.actor(sequences, num_actions, attention_mask) + base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask) + value = self.critic(sequences, action_mask, attention_mask) + r = self.reward_model(sequences, attention_mask) + reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask) + + advantage = reward - value + # TODO(ver217): maybe normalize adv + if advantage.ndim == 1: + advantage = advantage.unsqueeze(-1) + + return Experience(sequences, action_log_probs, value, reward, advantage, attention_mask, action_mask) diff --git a/applications/Chat/coati/kernels/__init__.py b/applications/Chat/coati/kernels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..230eedf7ecba9531dc53d9fafde34a15af806cb3 --- /dev/null +++ b/applications/Chat/coati/kernels/__init__.py @@ -0,0 +1,6 @@ +from .wrapper import convert_to_xformer_model, recover_from_xformer_model + +__all__ = [ + 'convert_to_xformer_model', + 'recover_from_xformer_model', +] diff --git a/applications/Chat/coati/kernels/opt_attn.py b/applications/Chat/coati/kernels/opt_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..c10f341e94a3bd32511f6da0c50dbf075136935d --- /dev/null +++ b/applications/Chat/coati/kernels/opt_attn.py @@ -0,0 +1,87 @@ +from typing import Optional, Tuple + +import torch +import xformers.ops as xops +from torch import Tensor +from transformers.models.opt.modeling_opt import OPTAttention + + +# This is modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py +class XOPTAttention(OPTAttention): + # def _shape(self, tensor: Tensor, seq_len: int, bsz: int): + # return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: Tensor, + key_value_states: Optional[Tensor] = None, + past_key_value: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + layer_head_mask: Optional[Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]: + if not self.training: + return super().forward(hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, + output_attentions) + """Input shape: Batch x Time x Channel""" + assert layer_head_mask is None, 'Xformers attention does not support layer_head_mask' + assert not output_attentions, 'Xformers attention does not support output_attentions' + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_states = self._shape(query_states, tgt_len, bsz).transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = xops.memory_efficient_attention(query_states, + key_states, + value_states, + attn_bias=xops.LowerTriangularMask(), + p=self.dropout if self.training else 0.0, + scale=self.scaling) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + attn_weights_reshaped = None + + return attn_output, attn_weights_reshaped, past_key_value diff --git a/applications/Chat/coati/kernels/wrapper.py b/applications/Chat/coati/kernels/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..c55bda6002303227e7e26a8cfacbbd0f5ee599b0 --- /dev/null +++ b/applications/Chat/coati/kernels/wrapper.py @@ -0,0 +1,18 @@ +import torch.nn as nn +from transformers.models.opt.modeling_opt import OPTAttention + +from .opt_attn import XOPTAttention + + +def convert_to_xformer_model(model: nn.Module) -> nn.Module: + for module in model.modules(): + if isinstance(module, OPTAttention): + module.__class__ = XOPTAttention + return model + + +def recover_from_xformer_model(model: nn.Module) -> nn.Module: + for module in model.modules(): + if isinstance(module, XOPTAttention): + module.__class__ = OPTAttention + return model diff --git a/applications/Chat/coati/models/__init__.py b/applications/Chat/coati/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..709bc5ac0948ec13ca251b77a944b96e6543e97a --- /dev/null +++ b/applications/Chat/coati/models/__init__.py @@ -0,0 +1,8 @@ +from .base import Actor, Critic, RewardModel +from .lora import LoRAModule, convert_to_lora_module +from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss + +__all__ = [ + 'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss', + 'LoRAModule', 'convert_to_lora_module' +] diff --git a/applications/Chat/coati/models/base/__init__.py b/applications/Chat/coati/models/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4152f2b760bf38924f521ee1d70a89074c8004 --- /dev/null +++ b/applications/Chat/coati/models/base/__init__.py @@ -0,0 +1,24 @@ +import torch.nn as nn + +from .actor import Actor +from .critic import Critic +from .reward_model import RewardModel + + +def get_base_model(model: nn.Module) -> nn.Module: + """Get the base model of our wrapper classes. + For Actor, it's base model is ``actor.model`` and it's usually a ``transformers.PreTrainedModel``. + For Critic and RewardModel, it's base model is itself. + + Args: + model (nn.Module): model to get base model from + + Returns: + nn.Module: the base model + """ + if isinstance(model, Actor): + return model.get_base_model() + return model + + +__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model'] diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py new file mode 100644 index 0000000000000000000000000000000000000000..71fbf7bbae7d90875a202b5e4e56e431efc1468e --- /dev/null +++ b/applications/Chat/coati/models/base/actor.py @@ -0,0 +1,65 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..generation import generate +from ..lora import LoRAModule +from ..utils import log_probs_from_logits + + +class Actor(LoRAModule): + """ + Actor model base class. + + Args: + model (nn.Module): Actor Model. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: + super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) + self.model = model + self.convert_to_lora() + + @torch.no_grad() + def generate( + self, + input_ids: torch.Tensor, + return_action_mask: bool = True, + **kwargs + ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: + sequences = generate(self.model, input_ids, **kwargs) + attention_mask = None + pad_token_id = kwargs.get('pad_token_id', None) + if pad_token_id is not None: + attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) + if not return_action_mask: + return sequences, attention_mask, None + input_len = input_ids.size(1) + eos_token_id = kwargs.get('eos_token_id', None) + if eos_token_id is None: + action_mask = torch.ones_like(sequences, dtype=torch.bool) + else: + # left padding may be applied, only mask action + action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input + action_mask[:, :input_len] = False + action_mask = action_mask[:, 1:] + return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):] + + def forward(self, + sequences: torch.LongTensor, + num_actions: int, + attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Returns action log probs + """ + output = self.model(sequences, attention_mask=attention_mask) + logits = output['logits'] + log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + return log_probs[:, -num_actions:] + + def get_base_model(self): + return self.model diff --git a/applications/Chat/coati/models/base/critic.py b/applications/Chat/coati/models/base/critic.py new file mode 100644 index 0000000000000000000000000000000000000000..e68a743a7762094c255df002565d61bdc611479f --- /dev/null +++ b/applications/Chat/coati/models/base/critic.py @@ -0,0 +1,54 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from ..lora import LoRAModule +from ..utils import masked_mean + + +class Critic(LoRAModule): + """ + Critic model base class. + + Args: + model (nn.Module): Critic model. + value_head (nn.Module): Value head to get value. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__( + self, + model: nn.Module, + value_head: nn.Module, + lora_rank: int = 0, + lora_train_bias: str = 'none', + use_action_mask: bool = False, + ) -> None: + + super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) + self.model = model + self.value_head = value_head + self.use_action_mask = use_action_mask + self.convert_to_lora() + + def forward(self, + sequences: torch.LongTensor, + action_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + outputs = self.model(sequences, attention_mask=attention_mask) + last_hidden_states = outputs['last_hidden_state'] + + values = self.value_head(last_hidden_states).squeeze(-1) + + if action_mask is not None and self.use_action_mask: + num_actions = action_mask.size(1) + prompt_mask = attention_mask[:, :-num_actions] + values = values[:, :-num_actions] + value = masked_mean(values, prompt_mask, dim=1) + return value + + values = values[:, :-1] + value = values.mean(dim=1) + return value diff --git a/applications/Chat/coati/models/base/reward_model.py b/applications/Chat/coati/models/base/reward_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ce8c0a1d35687da2cb61b55c59be6c91c0f9c394 --- /dev/null +++ b/applications/Chat/coati/models/base/reward_model.py @@ -0,0 +1,41 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from ..lora import LoRAModule + + +class RewardModel(LoRAModule): + """ + Reward model base class. + + Args: + model (nn.Module): Reward model. + value_head (nn.Module): Value head to get reward score. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + model: nn.Module, + value_head: Optional[nn.Module] = None, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) + self.model = model + self.convert_to_lora() + + if value_head is not None: + if value_head.out_features != 1: + raise ValueError("The value head of reward model's output dim should be 1!") + self.value_head = value_head + else: + self.value_head = nn.Linear(model.config.n_embd, 1) + + def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + outputs = self.model(sequences, attention_mask=attention_mask) + last_hidden_states = outputs['last_hidden_state'] + values = self.value_head(last_hidden_states)[:, :-1] + value = values.mean(dim=1).squeeze(1) # ensure shape is (B) + return value diff --git a/applications/Chat/coati/models/bloom/__init__.py b/applications/Chat/coati/models/bloom/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e7f7b1ef94e2a8434fa36efe6b5a2e4586f5b8 --- /dev/null +++ b/applications/Chat/coati/models/bloom/__init__.py @@ -0,0 +1,5 @@ +from .bloom_actor import BLOOMActor +from .bloom_critic import BLOOMCritic +from .bloom_rm import BLOOMRM + +__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM'] diff --git a/applications/Chat/coati/models/bloom/bloom_actor.py b/applications/Chat/coati/models/bloom/bloom_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..d7577f0964934955992f8d952c4e209a68a58829 --- /dev/null +++ b/applications/Chat/coati/models/bloom/bloom_actor.py @@ -0,0 +1,35 @@ +from typing import Optional + +import torch +from transformers import BloomConfig, BloomForCausalLM, BloomModel + +from ..base import Actor + + +class BLOOMActor(Actor): + """ + BLOOM Actor model. + + Args: + pretrained (str): Pretrained model name or path. + config (BloomConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = BloomForCausalLM.from_pretrained(pretrained) + elif config is not None: + model = BloomForCausalLM(config) + else: + model = BloomForCausalLM(BloomConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + super().__init__(model, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/bloom/bloom_critic.py b/applications/Chat/coati/models/bloom/bloom_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..a32fb2e102f9b05bb5c62d97f215343ce466b840 --- /dev/null +++ b/applications/Chat/coati/models/bloom/bloom_critic.py @@ -0,0 +1,38 @@ +from typing import Optional + +import torch +import torch.nn as nn +from transformers import BloomConfig, BloomForCausalLM, BloomModel + +from ..base import Critic + + +class BLOOMCritic(Critic): + """ + BLOOM Critic model. + + Args: + pretrained (str): Pretrained model name or path. + config (BloomConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none', + **kwargs) -> None: + if pretrained is not None: + model = BloomModel.from_pretrained(pretrained) + elif config is not None: + model = BloomModel(config) + else: + model = BloomModel(BloomConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.hidden_size, 1) + super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/bloom/bloom_rm.py b/applications/Chat/coati/models/bloom/bloom_rm.py new file mode 100644 index 0000000000000000000000000000000000000000..22cfab441abb6666812a5875e02da77b28357e94 --- /dev/null +++ b/applications/Chat/coati/models/bloom/bloom_rm.py @@ -0,0 +1,37 @@ +from typing import Optional + +import torch.nn as nn +from transformers import BloomConfig, BloomForCausalLM, BloomModel + +from ..base import RewardModel + + +class BLOOMRM(RewardModel): + """ + BLOOM Reward model. + + Args: + pretrained (str): Pretrained model name or path. + config (BloomConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = BloomModel.from_pretrained(pretrained) + elif config is not None: + model = BloomModel(config) + else: + model = BloomModel(BloomConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.hidden_size, 1) + value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1)) + super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/deberta/__init__.py b/applications/Chat/coati/models/deberta/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b66888f34fd0b81b726b28dec39b9ed26d99c945 --- /dev/null +++ b/applications/Chat/coati/models/deberta/__init__.py @@ -0,0 +1,4 @@ +from .deberta_critic import DebertaCritic +from .deberta_rm import DebertaRM + +__all__ = ['DebertaCritic', 'DebertaRM'] diff --git a/applications/Chat/coati/models/deberta/deberta_critic.py b/applications/Chat/coati/models/deberta/deberta_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..e84c1dbd8380728a6544c50be2b82146821fb3c3 --- /dev/null +++ b/applications/Chat/coati/models/deberta/deberta_critic.py @@ -0,0 +1,36 @@ +from typing import Optional + +import torch.nn as nn +from transformers import DebertaV2Config, DebertaV2Model + +from ..base import Critic + + +class DebertaCritic(Critic): + """ + Deberta Critic model. + + Args: + pretrained (str): Pretrained model name or path. + config (DebertaV2Config): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the LO-RA decomposition. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[DebertaV2Config] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = DebertaV2Model.from_pretrained(pretrained) + elif config is not None: + model = DebertaV2Model(config) + else: + model = DebertaV2Model(DebertaV2Config()) + if checkpoint: + model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.hidden_size, 1) + super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/deberta/deberta_rm.py b/applications/Chat/coati/models/deberta/deberta_rm.py new file mode 100644 index 0000000000000000000000000000000000000000..2448c879ec859ebfd13afd859508e28b63427c06 --- /dev/null +++ b/applications/Chat/coati/models/deberta/deberta_rm.py @@ -0,0 +1,37 @@ +from typing import Optional + +import torch.nn as nn +from transformers import DebertaV2Config, DebertaV2Model + +from ..base import RewardModel + + +class DebertaRM(RewardModel): + """ + Deberta Reward model. + + Args: + pretrained (str): Pretrained model name or path. + config (DebertaV2Config): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the LO-RA decomposition. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: str = None, + config: Optional[DebertaV2Config] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = DebertaV2Model.from_pretrained(pretrained) + elif config is not None: + model = DebertaV2Model(config) + else: + model = DebertaV2Model(DebertaV2Config()) + if checkpoint: + model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.hidden_size, 1) + value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1)) + super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py new file mode 100644 index 0000000000000000000000000000000000000000..f57c9458a271131a14d6b035cd039d2a5de1281b --- /dev/null +++ b/applications/Chat/coati/models/generation.py @@ -0,0 +1,146 @@ +from typing import Any, Callable, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn + +try: + from transformers.generation_logits_process import ( + LogitsProcessorList, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + ) +except ImportError: + from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper + + +def prepare_logits_processor(top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None) -> LogitsProcessorList: + processor_list = LogitsProcessorList() + if temperature is not None and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if top_k is not None and top_k != 0: + processor_list.append(TopKLogitsWarper(top_k)) + if top_p is not None and top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + return processor_list + + +def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: + if dist.is_initialized() and dist.get_world_size() > 1: + # consider DP + unfinished_sequences = unfinished_sequences.clone() + dist.all_reduce(unfinished_sequences) + return unfinished_sequences.max() == 0 + + +def sample(model: nn.Module, + input_ids: torch.Tensor, + max_length: int, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs) -> torch.Tensor: + if input_ids.size(1) >= max_length: + return input_ids + + logits_processor = prepare_logits_processor(top_k, top_p, temperature) + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + for _ in range(input_ids.size(1), max_length): + model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else { + 'input_ids': input_ids + } + outputs = model(**model_inputs) + + next_token_logits = outputs['logits'][:, -1, :] + # pre-process distribution + next_token_logits = logits_processor(input_ids, next_token_logits) + # sample + probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if update_model_kwargs_fn is not None: + model_kwargs = update_model_kwargs_fn(outputs, model_kwargs) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + + # stop when each sentence is finished if early_stopping=True + if early_stopping and _is_sequence_finished(unfinished_sequences): + break + + return input_ids + + +def generate(model: nn.Module, + input_ids: torch.Tensor, + max_length: int, + num_beams: int = 1, + do_sample: bool = True, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs) -> torch.Tensor: + """Generate token sequence. The returned sequence is input_ids + generated_tokens. + + Args: + model (nn.Module): model + input_ids (torch.Tensor): input sequence + max_length (int): max length of the returned sequence + num_beams (int, optional): number of beams. Defaults to 1. + do_sample (bool, optional): whether to do sample. Defaults to True. + early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False. + eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None. + pad_token_id (Optional[int], optional): pad token id. Defaults to None. + top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None. + top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None. + temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None. + prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None. + update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None. + """ + is_greedy_gen_mode = ((num_beams == 1) and do_sample is False) + is_sample_gen_mode = ((num_beams == 1) and do_sample is True) + is_beam_gen_mode = ((num_beams > 1) and do_sample is False) + if is_greedy_gen_mode: + # run greedy search + raise NotImplementedError + elif is_sample_gen_mode: + # run sample + return sample(model, + input_ids, + max_length, + early_stopping=early_stopping, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + top_k=top_k, + top_p=top_p, + temperature=temperature, + prepare_inputs_fn=prepare_inputs_fn, + update_model_kwargs_fn=update_model_kwargs_fn, + **model_kwargs) + elif is_beam_gen_mode: + raise NotImplementedError + else: + raise ValueError("Unsupported generation mode") diff --git a/applications/Chat/coati/models/gpt/__init__.py b/applications/Chat/coati/models/gpt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..63dc5ab0f5ead4450e1f9bffae874520614fa05c --- /dev/null +++ b/applications/Chat/coati/models/gpt/__init__.py @@ -0,0 +1,5 @@ +from .gpt_actor import GPTActor +from .gpt_critic import GPTCritic +from .gpt_rm import GPTRM + +__all__ = ['GPTActor', 'GPTCritic', 'GPTRM'] diff --git a/applications/Chat/coati/models/gpt/gpt_actor.py b/applications/Chat/coati/models/gpt/gpt_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..ae9d669f1f5669dc63e829332ccdf7ebf991afe2 --- /dev/null +++ b/applications/Chat/coati/models/gpt/gpt_actor.py @@ -0,0 +1,36 @@ +from typing import Optional + +from transformers.models.gpt2.configuration_gpt2 import GPT2Config +from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel + +from ..base import Actor + + +class GPTActor(Actor): + """ + GPT Actor model. + + Args: + pretrained (str): Pretrained model name or path. + config (GPT2Config): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the LoRa layer. + lora_train_bias (str): Bias training strategy for the LoRa layer. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none', + **kwargs) -> None: + if pretrained is not None: + model = GPT2LMHeadModel.from_pretrained(pretrained) + elif config is not None: + model = GPT2LMHeadModel(config) + else: + model = GPT2LMHeadModel(GPT2Config()) + if checkpoint: + model.gradient_checkpointing_enable() + super().__init__(model, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/gpt/gpt_critic.py b/applications/Chat/coati/models/gpt/gpt_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..2e70f5f1fc9632edbd2d83ef7b0067f7accd0eda --- /dev/null +++ b/applications/Chat/coati/models/gpt/gpt_critic.py @@ -0,0 +1,38 @@ +from typing import Optional + +import torch.nn as nn +from transformers.models.gpt2.configuration_gpt2 import GPT2Config +from transformers.models.gpt2.modeling_gpt2 import GPT2Model + +from ..base import Critic + + +class GPTCritic(Critic): + """ + GPT Critic model. + + Args: + pretrained (str): Pretrained model name or path. + config (GPT2Config): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the LO-RA decomposition. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none', + **kwargs) -> None: + if pretrained is not None: + model = GPT2Model.from_pretrained(pretrained) + elif config is not None: + model = GPT2Model(config) + else: + model = GPT2Model(GPT2Config()) + if checkpoint: + model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.n_embd, 1) + super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/gpt/gpt_rm.py b/applications/Chat/coati/models/gpt/gpt_rm.py new file mode 100644 index 0000000000000000000000000000000000000000..054432e1ce863a36c169ea3b7569198ab1994e6c --- /dev/null +++ b/applications/Chat/coati/models/gpt/gpt_rm.py @@ -0,0 +1,39 @@ +from typing import Optional + +import torch.nn as nn +from transformers.models.gpt2.configuration_gpt2 import GPT2Config +from transformers.models.gpt2.modeling_gpt2 import GPT2Model + +from ..base import RewardModel + + +class GPTRM(RewardModel): + """ + GPT Reward model. + + Args: + pretrained (str): Pretrained model name or path. + config (GPT2Config): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the low-rank approximation. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = GPT2Model.from_pretrained(pretrained) + elif config is not None: + model = GPT2Model(config) + else: + model = GPT2Model(GPT2Config()) + if checkpoint: + model.gradient_checkpointing_enable() + + value_head = nn.Linear(model.config.n_embd, 1) + value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1)) + super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/llama/__init__.py b/applications/Chat/coati/models/llama/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b2a024afdb28d3336d801895da3f26b01f28c56 --- /dev/null +++ b/applications/Chat/coati/models/llama/__init__.py @@ -0,0 +1,5 @@ +from .llama_actor import LlamaActor +from .llama_critic import LlamaCritic +from .llama_rm import LlamaRM + +__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM'] diff --git a/applications/Chat/coati/models/llama/llama_actor.py b/applications/Chat/coati/models/llama/llama_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7adb390d8bea055e9fc84f75dc66658a8fe0e3 --- /dev/null +++ b/applications/Chat/coati/models/llama/llama_actor.py @@ -0,0 +1,38 @@ +from typing import Optional + +import torch +from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM + +from ..base import Actor + + +class LlamaActor(Actor): + """ + Llama Actor model. + + Args: + pretrained (str): Pretrained model name or path. + config (LlamaConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + + if pretrained is not None: + model = LlamaForCausalLM.from_pretrained(pretrained) + elif config is not None: + model = LlamaForCausalLM(config) + else: + model = LlamaForCausalLM(LlamaConfig()) + + if checkpoint: + model.gradient_checkpointing_enable() + + super().__init__(model, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..dd9e5e7bfa1ae6da1c9c09789ae0eda602a35720 --- /dev/null +++ b/applications/Chat/coati/models/llama/llama_critic.py @@ -0,0 +1,41 @@ +from typing import Optional + +import torch.nn as nn +from transformers import LlamaConfig, LlamaModel + +from ..base import Critic + + +class LlamaCritic(Critic): + """ + Llama Critic model. + + Args: + pretrained (str): Pretrained model name or path. + config (LlamaConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none', + **kwargs) -> None: + + if pretrained is not None: + model = LlamaModel.from_pretrained(pretrained) + elif config is not None: + model = LlamaModel(config) + else: + model = LlamaModel(LlamaConfig()) + + if checkpoint: + model.gradient_checkpointing_enable() + + value_head = nn.Linear(model.config.hidden_size, 1) + + super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/llama/llama_rm.py b/applications/Chat/coati/models/llama/llama_rm.py new file mode 100644 index 0000000000000000000000000000000000000000..f936019d62d28bac0bf41161b3f4aaad26e2bbf0 --- /dev/null +++ b/applications/Chat/coati/models/llama/llama_rm.py @@ -0,0 +1,40 @@ +from typing import Optional + +import torch.nn as nn +from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel + +from ..base import RewardModel + + +class LlamaRM(RewardModel): + """ + Llama Reward model. + + Args: + pretrained (str): Pretrained model name or path. + config (LlamaConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + + if pretrained is not None: + model = LlamaModel.from_pretrained(pretrained) + elif config is not None: + model = LlamaModel(config) + else: + model = LlamaModel(LlamaConfig()) + + if checkpoint: + model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.hidden_size, 1) + value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1)) + + super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..0533a60dc53266d592fe2b5808eb637ac64db875 --- /dev/null +++ b/applications/Chat/coati/models/lora.py @@ -0,0 +1,143 @@ +import math +from typing import Optional + +import loralib as lora +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LoraLinear(lora.LoRALayer, nn.Module): + """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear. + """ + + def __init__( + self, + weight: nn.Parameter, + bias: Optional[nn.Parameter], + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0., + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + merge_weights: bool = True, + ): + nn.Module.__init__(self) + lora.LoRALayer.__init__(self, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=merge_weights) + self.weight = weight + self.bias = bias + + out_features, in_features = weight.shape + self.in_features = in_features + self.out_features = out_features + + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) + self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + if fan_in_fan_out: + self.weight.data = self.weight.data.T + + def reset_parameters(self): + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def train(self, mode: bool = True): + + def T(w): + return w.T if self.fan_in_fan_out else w + + nn.Module.train(self, mode) + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + if self.r > 0: + self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling + self.merged = False + + def eval(self): + + def T(w): + return w.T if self.fan_in_fan_out else w + + nn.Module.eval(self) + if self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0: + self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling + delattr(self, 'lora_A') + delattr(self, 'lora_B') + self.merged = True + + def forward(self, x: torch.Tensor): + + def T(w): + return w.T if self.fan_in_fan_out else w + + if self.r > 0 and not self.merged: + result = F.linear(x, T(self.weight), bias=self.bias) + if self.r > 0: + result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling + return result + else: + return F.linear(x, T(self.weight), bias=self.bias) + + +def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: + assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})' + lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False) + return lora_linear + + +def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None: + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + setattr(module, name, lora_linear_wrapper(child, lora_rank)) + else: + convert_to_lora_recursively(child, lora_rank) + + +def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module: + """Convert a torch.nn.Module to a LoRA module. + + Args: + module (nn.Module): The module to convert. + lora_rank (int): LoRA rank. + + Returns: + nn.Module: The converted module. + """ + if lora_rank <= 0: + return module + convert_to_lora_recursively(module, lora_rank) + lora.mark_only_lora_as_trainable(module, lora_train_bias) + return module + + +class LoRAModule(nn.Module): + """A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`. + This class will convert all torch.nn.Linear layer to LoraLinear layer. + + Args: + lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0. + lora_train_bias (str, optional): Whether LoRA train biases. + 'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers. + Defaults to 'none'. + """ + + def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: + super().__init__() + self.lora_rank = lora_rank + self.lora_train_bias = lora_train_bias + + def convert_to_lora(self) -> None: + convert_to_lora_module(self, self.lora_rank, self.lora_train_bias) diff --git a/applications/Chat/coati/models/loss.py b/applications/Chat/coati/models/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..926c6e2a4e4131ece0693630bce4894e8cbec5f0 --- /dev/null +++ b/applications/Chat/coati/models/loss.py @@ -0,0 +1,117 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from .utils import masked_mean + + +class GPTLMLoss(nn.Module): + """ + GPT Language Model Loss + """ + + def __init__(self): + super().__init__() + self.loss = nn.CrossEntropyLoss() + + def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +class PolicyLoss(nn.Module): + """ + Policy Loss for PPO + """ + + def __init__(self, clip_eps: float = 0.2) -> None: + super().__init__() + self.clip_eps = clip_eps + + def forward(self, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + ratio = (log_probs - old_log_probs).exp() + surr1 = ratio * advantages + surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages + loss = -torch.min(surr1, surr2) + if action_mask is not None: + loss = masked_mean(loss, action_mask) + loss = loss.mean() + return loss + + +class ValueLoss(nn.Module): + """ + Value Loss for PPO + """ + + def __init__(self, clip_eps: float = 0.4) -> None: + super().__init__() + self.clip_eps = clip_eps + + def forward(self, + values: torch.Tensor, + old_values: torch.Tensor, + reward: torch.Tensor, + action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps) + surr1 = (values_clipped - reward)**2 + surr2 = (values - reward)**2 + loss = torch.max(surr1, surr2) + loss = loss.mean() + return 0.5 * loss + + +class PPOPtxActorLoss(nn.Module): + """ + To Do: + + PPO-ptx Actor Loss + """ + + def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None: + super().__init__() + self.pretrain_coef = pretrain_coef + self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps) + self.pretrain_loss_fn = pretrain_loss_fn + + def forward(self, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + lm_logits: torch.Tensor, + lm_input_ids: torch.Tensor, + action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask) + lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids) + return policy_loss + self.pretrain_coef * lm_loss + + +class LogSigLoss(nn.Module): + """ + Pairwise Loss for Reward Model + Details: https://arxiv.org/abs/2203.02155 + """ + + def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor: + probs = torch.sigmoid(chosen_reward - reject_reward) + log_probs = torch.log(probs) + loss = -log_probs.mean() + return loss + + +class LogExpLoss(nn.Module): + """ + Pairwise Loss for Reward Model + Details: https://arxiv.org/abs/2204.05862 + """ + + def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor: + loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean() + return loss diff --git a/applications/Chat/coati/models/opt/__init__.py b/applications/Chat/coati/models/opt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..334f4df0032a1da8f8a3d23c9c987448c9dbc0d8 --- /dev/null +++ b/applications/Chat/coati/models/opt/__init__.py @@ -0,0 +1,5 @@ +from .opt_actor import OPTActor +from .opt_critic import OPTCritic +from .opt_rm import OPTRM + +__all__ = ['OPTActor', 'OPTCritic', 'OPTRM'] diff --git a/applications/Chat/coati/models/opt/opt_actor.py b/applications/Chat/coati/models/opt/opt_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..c14e4377ffb2b00983f4ea4a1af7c9931e9d1cb9 --- /dev/null +++ b/applications/Chat/coati/models/opt/opt_actor.py @@ -0,0 +1,35 @@ +from typing import Optional + +from transformers.models.opt.configuration_opt import OPTConfig +from transformers.models.opt.modeling_opt import OPTForCausalLM + +from ..base import Actor + + +class OPTActor(Actor): + """ + OPT Actor model. + + Args: + pretrained (str): Pretrained model name or path. + config (OPTConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the low-rank approximation. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = OPTForCausalLM.from_pretrained(pretrained) + elif config is not None: + model = OPTForCausalLM(config) + else: + model = OPTForCausalLM(OPTConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + super().__init__(model, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/opt/opt_critic.py b/applications/Chat/coati/models/opt/opt_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..fcfebd8a8b031785d93d709aead85d5aa6f30d08 --- /dev/null +++ b/applications/Chat/coati/models/opt/opt_critic.py @@ -0,0 +1,38 @@ +from typing import Optional + +import torch.nn as nn +from transformers.models.opt.configuration_opt import OPTConfig +from transformers.models.opt.modeling_opt import OPTModel + +from ..base import Critic + + +class OPTCritic(Critic): + """ + OPT Critic model. + + Args: + pretrained (str): Pretrained model name or path. + config (OPTConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the low-rank approximation. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none', + **kwargs) -> None: + if pretrained is not None: + model = OPTModel.from_pretrained(pretrained) + elif config is not None: + model = OPTModel(config) + else: + model = OPTModel(OPTConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.word_embed_proj_dim, 1) + super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/opt/opt_rm.py b/applications/Chat/coati/models/opt/opt_rm.py new file mode 100644 index 0000000000000000000000000000000000000000..50fc0dee8568f86d8d63e712104306e8c1e013b8 --- /dev/null +++ b/applications/Chat/coati/models/opt/opt_rm.py @@ -0,0 +1,38 @@ +from typing import Optional + +import torch.nn as nn +from transformers import OPTConfig, OPTModel + +from ..base import RewardModel + + +class OPTRM(RewardModel): + """ + OPT Reward model. + + Args: + pretrained (str): Pretrained model name or path. + config (OPTConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the low-rank approximation. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = OPTModel.from_pretrained(pretrained) + elif config is not None: + model = OPTModel(config) + else: + model = OPTModel(OPTConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + + value_head = nn.Linear(model.config.word_embed_proj_dim, 1) + value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.word_embed_proj_dim + 1)) + super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/roberta/__init__.py b/applications/Chat/coati/models/roberta/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f4a8de067b1695c4abbc880f2e3dc1f6510ad26 --- /dev/null +++ b/applications/Chat/coati/models/roberta/__init__.py @@ -0,0 +1,5 @@ +from .roberta_actor import RoBERTaActor +from .roberta_critic import RoBERTaCritic +from .roberta_rm import RoBERTaRM + +__all__ = ['RoBERTaActor', 'RoBERTaCritic', 'RoBERTaRM'] \ No newline at end of file diff --git a/applications/Chat/coati/models/roberta/roberta_actor.py b/applications/Chat/coati/models/roberta/roberta_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..e35fa6eb19a8053a2ea5cec3f2a07a2dd3c80735 --- /dev/null +++ b/applications/Chat/coati/models/roberta/roberta_actor.py @@ -0,0 +1,35 @@ +from typing import Optional + +from transformers.models.roberta.configuration_roberta import RobertaConfig +from transformers.models.roberta.modeling_roberta import RobertaForCausalLM + +from ..base import Actor + +class RoBERTaActor(Actor): + """ + RoBERTa Actor model. + + Args: + pretrained (str): Pretrained model name or path. + config (RoBERTaConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the low-rank approximation. + lora_train_bias (str): LoRA bias training mode. + """ + + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[RobertaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = RobertaForCausalLM.from_pretrained(pretrained) + elif config is not None: + model = RobertaForCausalLM(config) + else: + model = RobertaForCausalLM(RobertaConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + super().__init__(model, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/roberta/roberta_critic.py b/applications/Chat/coati/models/roberta/roberta_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..c8dc0d9e14f2813907a6345ffbf93784e9b8528c --- /dev/null +++ b/applications/Chat/coati/models/roberta/roberta_critic.py @@ -0,0 +1,38 @@ +from typing import Optional + +import torch.nn as nn +from transformers.models.roberta.configuration_roberta import RobertaConfig +from transformers.models.roberta.modeling_roberta import RobertaModel + +from ..base import Critic + + +class RoBERTaCritic(Critic): + """ + RoBERTa Critic model. + + Args: + pretrained (str): Pretrained model name or path. + config (RoBERTa Config): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the low-rank approximation. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[RobertaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none', + **kwargs) -> None: + if pretrained is not None: + model = RobertaModel.from_pretrained(pretrained, add_pooling_layer=False) + elif config is not None: + model = RobertaModel(config) + else: + model = RobertaModel(RobertaConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.hidden_size, 1) + super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/roberta/roberta_rm.py b/applications/Chat/coati/models/roberta/roberta_rm.py new file mode 100644 index 0000000000000000000000000000000000000000..77075052978b56d9bde336b3ac0473a343f6c332 --- /dev/null +++ b/applications/Chat/coati/models/roberta/roberta_rm.py @@ -0,0 +1,39 @@ +from typing import Optional + +import torch.nn as nn +from transformers import RobertaConfig, RobertaModel + + +from ..base import RewardModel + + +class RoBERTaRM(RewardModel): + """ + RoBERTa Reward model. + + Args: + pretrained (str): Pretrained model name or path. + config (RoBERTaConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the low-rank approximation. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[RobertaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = RobertaModel.from_pretrained(pretrained, add_pooling_layer=False) + elif config is not None: + model = RobertaModel(config) + else: + model = RobertaModel(RobertaConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + + value_head = nn.Linear(model.config.hidden_size, 1) + value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1)) + super().__init__(model, value_head, lora_rank, lora_train_bias) \ No newline at end of file diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0ff13181fcd2f9d36aa413c0ff4a3f2b04a1ca0d --- /dev/null +++ b/applications/Chat/coati/models/utils.py @@ -0,0 +1,92 @@ +from typing import Optional, Union + +import loralib as lora +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def compute_approx_kl(log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Compute the approximate KL divergence between two distributions. + Schulman blog: http://joschu.net/blog/kl-approx.html + + Args: + log_probs: Log probabilities of the new distribution. + log_probs_base: Log probabilities of the base distribution. + action_mask: Mask for actions. + """ + + log_ratio = log_probs - log_probs_base + approx_kl = (log_ratio.exp() - 1) - log_ratio + if action_mask is not None: + approx_kl = masked_mean(approx_kl, action_mask, dim=1) + return approx_kl + approx_kl = approx_kl.mean(dim=1) + return approx_kl + + +def compute_reward(r: Union[torch.Tensor, float], + kl_coef: float, + log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + if kl_coef <= 0.0: + return r + kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask) + reward = r - kl_coef * kl + return reward + + +def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + log_probs = F.log_softmax(logits, dim=-1) + log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) + return log_probs_labels.squeeze(-1) + + +def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: + tensor = tensor * mask + tensor = tensor.sum(dim=dim) + mask_sum = mask.sum(dim=dim) + mean = tensor / (mask_sum + 1e-8) + return mean + + +def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor: + tensor = tensor * mask + mean = masked_mean(tensor, mask, dim=dim) + mean_centered = tensor - mean + var = masked_mean(mean_centered**2, mask, dim=dim) + return mean_centered * var.clamp(min=eps).rsqrt() + + +def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor: + mean = tensor.mean(dim) + mean_centered = tensor - mean + var = (mean_centered**2).mean(dim) + norm = mean_centered * var.clamp(min=eps).rsqrt() + return norm + + +def convert_to_lora(model: nn.Module, + input_size: int, + output_size: int, + lora_rank: int = 16, + lora_alpha: int = 1, + lora_dropout: float = 0., + fan_in_fan_out: bool = False, + merge_weights: bool = True): + if lora_rank > min(input_size, output_size): + raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}") + + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + module._modules[name] = lora.Linear(input_size, + output_size, + r=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + fan_in_fan_out=fan_in_fan_out, + merge_weights=merge_weights) diff --git a/applications/Chat/coati/ray/__init__.py b/applications/Chat/coati/ray/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5802c05bc03feb3b755754de7385a2141e547673 --- /dev/null +++ b/applications/Chat/coati/ray/__init__.py @@ -0,0 +1,2 @@ +from .src.detached_replay_buffer import DetachedReplayBuffer +from .src.detached_trainer_ppo import DetachedPPOTrainer diff --git a/applications/Chat/coati/ray/example/1m1t.py b/applications/Chat/coati/ray/example/1m1t.py new file mode 100644 index 0000000000000000000000000000000000000000..a6527370505b9ce87a1985b8b3d45bcbf0c103a3 --- /dev/null +++ b/applications/Chat/coati/ray/example/1m1t.py @@ -0,0 +1,153 @@ +import argparse +from copy import deepcopy + +import pandas as pd +import torch +from coati.trainer import PPOTrainer + + +from coati.ray.src.experience_maker_holder import ExperienceMakerHolder +from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer + +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.experience_maker import NaiveExperienceMaker +from torch.optim import Adam +from transformers import AutoTokenizer, BloomTokenizerFast +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.nn.optimizer import HybridAdam + +import ray +import os +import socket + +def get_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + return s.getsockname()[1] + + +def get_local_ip(): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(('8.8.8.8', 80)) + return s.getsockname()[0] + +def main(args): + master_addr = str(get_local_ip()) + # trainer_env_info + trainer_port = str(get_free_port()) + env_info_trainer = {'local_rank' : '0', + 'rank' : '0', + 'world_size' : '1', + 'master_port' : trainer_port, + 'master_addr' : master_addr} + + # maker_env_info + maker_port = str(get_free_port()) + env_info_maker = {'local_rank' : '0', + 'rank' : '0', + 'world_size' : '1', + 'master_port' : maker_port, + 'master_addr' : master_addr} + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + else: + raise ValueError(f'Unsupported model "{args.model}"') + + # configure Trainer + trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=["maker1"], + strategy=args.trainer_strategy, + model=args.model, + env_info = env_info_trainer, + pretrained=args.pretrain, + lora_rank=args.lora_rank, + train_batch_size=args.train_batch_size, + buffer_limit=16, + experience_batch_size=args.experience_batch_size, + max_epochs=args.max_epochs, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + # configure Experience Maker + experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=["trainer1"], + strategy=args.maker_strategy, + env_info = env_info_maker, + experience_batch_size=args.experience_batch_size, + kl_coef=0.1, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + # trainer send its actor and critic to experience holders. + ray.get(trainer_ref.initialize_remote_makers.remote()) + + # configure sampler + dataset = pd.read_csv(args.prompt_path)['prompt'] + + def tokenize_fn(texts): + # MUST padding to max length to ensure inputs of all ranks have the same length + # Different length may lead to hang when using gemini, as different generation steps + batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + return {k: v.cuda() for k, v in batch.items()} + + trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) + num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3 # +3 for fault tolerance + maker_done_ref = experience_holder_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) + + ray.get([trainer_done_ref, maker_done_ref]) + + # save model checkpoint after fitting + trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('prompt_path') + parser.add_argument('--trainer_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--maker_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument('--debug', action='store_true') + args = parser.parse_args() + ray.init(namespace=os.environ["RAY_NAMESPACE"]) + main(args) diff --git a/applications/Chat/coati/ray/example/1m1t.sh b/applications/Chat/coati/ray/example/1m1t.sh new file mode 100644 index 0000000000000000000000000000000000000000..f7c5054c800eb376ae973a5103482aec97b0511f --- /dev/null +++ b/applications/Chat/coati/ray/example/1m1t.sh @@ -0,0 +1,23 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 2 + +export RAY_NAMESPACE="admin" + +python 1m1t.py "/path/to/prompts.csv" \ + --trainer_strategy colossalai_zero2 --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \ + --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \ + --max_epochs 10 --debug diff --git a/applications/Chat/coati/ray/example/1m2t.py b/applications/Chat/coati/ray/example/1m2t.py new file mode 100644 index 0000000000000000000000000000000000000000..3883c364a8e02fe0adcfa742599870d90f78631d --- /dev/null +++ b/applications/Chat/coati/ray/example/1m2t.py @@ -0,0 +1,186 @@ +import argparse +from copy import deepcopy + +import pandas as pd +import torch +from coati.trainer import PPOTrainer + + +from coati.ray.src.experience_maker_holder import ExperienceMakerHolder +from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer + +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.experience_maker import NaiveExperienceMaker +from torch.optim import Adam +from transformers import AutoTokenizer, BloomTokenizerFast +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.nn.optimizer import HybridAdam + +import ray +import os +import socket + + +def get_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + return s.getsockname()[1] + + +def get_local_ip(): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(('8.8.8.8', 80)) + return s.getsockname()[0] + +def main(args): + master_addr = str(get_local_ip()) + # trainer_env_info + trainer_port = str(get_free_port()) + env_info_trainer_1 = {'local_rank' : '0', + 'rank' : '0', + 'world_size' : '2', + 'master_port' : trainer_port, + 'master_addr' : master_addr} + env_info_trainer_2 = {'local_rank' : '0', + 'rank' : '1', + 'world_size' : '2', + 'master_port' : trainer_port, + 'master_addr' : master_addr} + # maker_env_info + maker_port = str(get_free_port()) + env_info_maker_1 = {'local_rank' : '0', + 'rank' : '0', + 'world_size' : '2', + 'master_port' : maker_port, + 'master_addr' : master_addr} + print([env_info_trainer_1, + env_info_trainer_2, + env_info_maker_1]) + ray.init(dashboard_port = 1145) + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + else: + raise ValueError(f'Unsupported model "{args.model}"') + + # configure Trainer + trainer_1_ref = DetachedPPOTrainer.options(name="trainer1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=["maker1"], + strategy=args.trainer_strategy, + model=args.model, + env_info=env_info_trainer_1, + pretrained=args.pretrain, + lora_rank=args.lora_rank, + train_batch_size=args.train_batch_size, + buffer_limit=16, + experience_batch_size=args.experience_batch_size, + max_epochs=args.max_epochs, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + trainer_2_ref = DetachedPPOTrainer.options(name="trainer2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=["maker1"], + strategy=args.trainer_strategy, + model=args.model, + env_info=env_info_trainer_2, + pretrained=args.pretrain, + lora_rank=args.lora_rank, + train_batch_size=args.train_batch_size, + buffer_limit=16, + experience_batch_size=args.experience_batch_size, + max_epochs=args.max_epochs, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug= args.debug, + ) + + # configure Experience Maker + experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=["trainer1", "trainer2"], + strategy=args.maker_strategy, + env_info=env_info_maker_1, + experience_batch_size=args.experience_batch_size, + kl_coef=0.1, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + # trainer send its actor and critic to experience holders. + # TODO: balance duty + ray.get(trainer_1_ref.initialize_remote_makers.remote()) + + # configure sampler + dataset = pd.read_csv(args.prompt_path)['prompt'] + + def tokenize_fn(texts): + # MUST padding to max length to ensure inputs of all ranks have the same length + # Different length may lead to hang when using gemini, as different generation steps + batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + return {k: v.cuda() for k, v in batch.items()} + + trainer_1_done_ref = trainer_1_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) + trainer_2_done_ref = trainer_2_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) + num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs * 2 + 3 # +3 for fault tolerance + maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) + + ray.get([trainer_1_done_ref, trainer_2_done_ref, maker_1_done_ref]) + # save model checkpoint after fitting + trainer_1_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) + trainer_2_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + trainer_1_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + trainer_2_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('prompt_path') + parser.add_argument('--trainer_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--maker_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument('--debug', action='store_true') + args = parser.parse_args() + main(args) diff --git a/applications/Chat/coati/ray/example/1m2t.sh b/applications/Chat/coati/ray/example/1m2t.sh new file mode 100644 index 0000000000000000000000000000000000000000..669f4141026c25ef405d2502a506d1bb019520ea --- /dev/null +++ b/applications/Chat/coati/ray/example/1m2t.sh @@ -0,0 +1,23 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 2 + +export RAY_NAMESPACE="admin" + +python 1m2t.py "/path/to/prompts.csv" --model gpt2 \ + --maker_strategy naive --trainer_strategy ddp --lora_rank 2 \ + --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \ + --max_epochs 10 #--debug \ No newline at end of file diff --git a/applications/Chat/coati/ray/example/2m1t.py b/applications/Chat/coati/ray/example/2m1t.py new file mode 100644 index 0000000000000000000000000000000000000000..b655de1ab1fa987987fbd5954b6c94db98b27c50 --- /dev/null +++ b/applications/Chat/coati/ray/example/2m1t.py @@ -0,0 +1,140 @@ +import argparse +from copy import deepcopy + +import pandas as pd +import torch +from coati.trainer import PPOTrainer + + +from coati.ray.src.experience_maker_holder import ExperienceMakerHolder +from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer + +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.experience_maker import NaiveExperienceMaker +from torch.optim import Adam +from transformers import AutoTokenizer, BloomTokenizerFast +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.nn.optimizer import HybridAdam + +import ray +import os +import socket + + +def main(args): + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + else: + raise ValueError(f'Unsupported model "{args.model}"') + + # configure Trainer + trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=["maker1", "maker2"], + strategy=args.trainer_strategy, + model=args.model, + pretrained=args.pretrain, + lora_rank=args.lora_rank, + train_batch_size=args.train_batch_size, + buffer_limit=16, + experience_batch_size=args.experience_batch_size, + max_epochs=args.max_epochs, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + # configure Experience Maker + experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=["trainer1"], + strategy=args.maker_strategy, + experience_batch_size=args.experience_batch_size, + kl_coef=0.1, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + experience_holder_2_ref = ExperienceMakerHolder.options(name="maker2", num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=["trainer1"], + strategy=args.maker_strategy, + experience_batch_size=args.experience_batch_size, + kl_coef=0.1, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + # trainer send its actor and critic to experience holders. + ray.get(trainer_ref.initialize_remote_makers.remote()) + + # configure sampler + dataset = pd.read_csv(args.prompt_path)['prompt'] + + def tokenize_fn(texts): + # MUST padding to max length to ensure inputs of all ranks have the same length + # Different length may lead to hang when using gemini, as different generation steps + batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + return {k: v.cuda() for k, v in batch.items()} + + trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) + num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs // 2 + 3 # +3 for fault tolerance + maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) + maker_2_done_ref = experience_holder_2_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) + + ray.get([trainer_done_ref, maker_1_done_ref, maker_2_done_ref]) + + # save model checkpoint after fitting + trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('prompt_path') + parser.add_argument('--trainer_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--maker_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument('--debug', action='store_true') + args = parser.parse_args() + ray.init(namespace=os.environ["RAY_NAMESPACE"]) + main(args) diff --git a/applications/Chat/coati/ray/example/2m1t.sh b/applications/Chat/coati/ray/example/2m1t.sh new file mode 100644 index 0000000000000000000000000000000000000000..a207d4118d605a8c3a17882bcf3cc6b1a8f32eb3 --- /dev/null +++ b/applications/Chat/coati/ray/example/2m1t.sh @@ -0,0 +1,23 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 3 + +export RAY_NAMESPACE="admin" + +python 2m1t.py "/path/to/prompts.csv" \ + --trainer_strategy naive --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \ + --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \ + --max_epochs 10 # --debug diff --git a/applications/Chat/coati/ray/example/2m2t.py b/applications/Chat/coati/ray/example/2m2t.py new file mode 100644 index 0000000000000000000000000000000000000000..435c71915fc2820b9b12bbd2204823f12ec2438f --- /dev/null +++ b/applications/Chat/coati/ray/example/2m2t.py @@ -0,0 +1,209 @@ +import argparse +from copy import deepcopy + +import pandas as pd +import torch +from coati.trainer import PPOTrainer + + +from coati.ray.src.experience_maker_holder import ExperienceMakerHolder +from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer + +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.experience_maker import NaiveExperienceMaker +from torch.optim import Adam +from transformers import AutoTokenizer, BloomTokenizerFast +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.nn.optimizer import HybridAdam + +import ray +import os +import socket + + +def get_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + return s.getsockname()[1] + + +def get_local_ip(): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(('8.8.8.8', 80)) + return s.getsockname()[0] + +def main(args): + master_addr = str(get_local_ip()) + # trainer_env_info + trainer_port = str(get_free_port()) + env_info_trainer_1 = {'local_rank' : '0', + 'rank' : '0', + 'world_size' : '2', + 'master_port' : trainer_port, + 'master_addr' : master_addr} + env_info_trainer_2 = {'local_rank' : '0', + 'rank' : '1', + 'world_size' : '2', + 'master_port' : trainer_port, + 'master_addr' : master_addr} + # maker_env_info + maker_port = str(get_free_port()) + env_info_maker_1 = {'local_rank' : '0', + 'rank' : '0', + 'world_size' : '2', + 'master_port' : maker_port, + 'master_addr' : master_addr} + env_info_maker_2 = {'local_rank' : '0', + 'rank' : '1', + 'world_size' : '2', + 'master_port': maker_port, + 'master_addr' : master_addr} + print([env_info_trainer_1, + env_info_trainer_2, + env_info_maker_1, + env_info_maker_2]) + ray.init() + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + else: + raise ValueError(f'Unsupported model "{args.model}"') + + # configure Trainer + trainer_1_ref = DetachedPPOTrainer.options(name="trainer1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=["maker1", "maker2"], + strategy=args.trainer_strategy, + model=args.model, + env_info=env_info_trainer_1, + pretrained=args.pretrain, + lora_rank=args.lora_rank, + train_batch_size=args.train_batch_size, + buffer_limit=16, + experience_batch_size=args.experience_batch_size, + max_epochs=args.max_epochs, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + trainer_2_ref = DetachedPPOTrainer.options(name="trainer2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=["maker1", "maker2"], + strategy=args.trainer_strategy, + model=args.model, + env_info=env_info_trainer_2, + pretrained=args.pretrain, + lora_rank=args.lora_rank, + train_batch_size=args.train_batch_size, + buffer_limit=16, + experience_batch_size=args.experience_batch_size, + max_epochs=args.max_epochs, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + # configure Experience Maker + experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=["trainer1", "trainer2"], + strategy=args.maker_strategy, + env_info=env_info_maker_1, + experience_batch_size=args.experience_batch_size, + kl_coef=0.1, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + experience_holder_2_ref = ExperienceMakerHolder.options(name="maker2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=["trainer1", "trainer2"], + strategy=args.maker_strategy, + env_info=env_info_maker_2, + experience_batch_size=args.experience_batch_size, + kl_coef=0.1, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + # trainer send its actor and critic to experience holders. + # TODO: balance duty + ray.get(trainer_1_ref.initialize_remote_makers.remote()) + + # configure sampler + dataset = pd.read_csv(args.prompt_path)['prompt'] + + def tokenize_fn(texts): + # MUST padding to max length to ensure inputs of all ranks have the same length + # Different length may lead to hang when using gemini, as different generation steps + batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + return {k: v.cuda() for k, v in batch.items()} + + trainer_1_done_ref = trainer_1_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) + trainer_2_done_ref = trainer_2_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) + num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3 # +3 for fault tolerance + maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) + maker_2_done_ref = experience_holder_2_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) + + ray.get([trainer_1_done_ref, trainer_2_done_ref, maker_1_done_ref, maker_2_done_ref]) + # save model checkpoint after fitting + trainer_1_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) + trainer_2_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + trainer_1_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + trainer_2_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('prompt_path') + parser.add_argument('--trainer_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--maker_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument('--debug', action='store_true') + args = parser.parse_args() + main(args) diff --git a/applications/Chat/coati/ray/example/2m2t.sh b/applications/Chat/coati/ray/example/2m2t.sh new file mode 100644 index 0000000000000000000000000000000000000000..fb4024766c54182efcd752f991bfc11c1db7588e --- /dev/null +++ b/applications/Chat/coati/ray/example/2m2t.sh @@ -0,0 +1,23 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 2 + +export RAY_NAMESPACE="admin" + +python 2m2t.py "path/to/prompts.csv" \ + --maker_strategy naive --trainer_strategy colossalai_zero2 --lora_rank 2 \ + --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \ + --max_epochs 10 --debug \ No newline at end of file diff --git a/applications/Chat/coati/ray/src/__init__.py b/applications/Chat/coati/ray/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/applications/Chat/coati/ray/src/detached_replay_buffer.py b/applications/Chat/coati/ray/src/detached_replay_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..855eee48c5a5ccb252ccd3e697073d728afcd3df --- /dev/null +++ b/applications/Chat/coati/ray/src/detached_replay_buffer.py @@ -0,0 +1,88 @@ +import torch +import random +from typing import List, Any +# from torch.multiprocessing import Queue +from ray.util.queue import Queue +import ray +import asyncio +from coati.experience_maker.base import Experience +from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch +from coati.replay_buffer import ReplayBuffer +from threading import Lock +import copy + +class DetachedReplayBuffer: + ''' + Detached replay buffer. Share Experience across workers on the same node. + Therefore a trainer node is expected to have only one instance. + It is ExperienceMakerHolder's duty to call append(exp) method, remotely. + + Args: + sample_batch_size: Batch size when sampling. Exp won't enqueue until they formed a batch. + tp_world_size: Number of workers in the same tp group + limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0. + cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True. + ''' + + def __init__(self, sample_batch_size: int, tp_world_size: int = 1, limit : int = 0, cpu_offload: bool = True) -> None: + self.cpu_offload = cpu_offload + self.sample_batch_size = sample_batch_size + self.limit = limit + self.items = Queue(self.limit, actor_options={"num_cpus":1}) + self.batch_collector : List[BufferItem] = [] + + ''' + Workers in the same tp group share this buffer and need same sample for one step. + Therefore a held_sample should be returned tp_world_size times before it could be dropped. + worker_state records wheter a worker got the held_sample + ''' + self.tp_world_size = tp_world_size + self.worker_state = [False] * self.tp_world_size + self.held_sample = None + self._worker_state_lock = Lock() + + @torch.no_grad() + def append(self, experience: Experience) -> None: + ''' + Expected to be called remotely. + ''' + if self.cpu_offload: + experience.to_device(torch.device('cpu')) + items = split_experience_batch(experience) + self.batch_collector.extend(items) + while len(self.batch_collector) >= self.sample_batch_size: + items = self.batch_collector[:self.sample_batch_size] + experience = make_experience_batch(items) + self.items.put(experience, block=True) + self.batch_collector = self.batch_collector[self.sample_batch_size:] + + def clear(self) -> None: + # self.items.close() + self.items.shutdown() + self.items = Queue(self.limit) + self.worker_state = [False] * self.tp_world_size + self.batch_collector = [] + + @torch.no_grad() + def sample(self, worker_rank = 0, to_device = "cpu") -> Experience: + self._worker_state_lock.acquire() + if not any(self.worker_state): + self.held_sample = self._sample_and_erase() + self.worker_state[worker_rank] = True + if all(self.worker_state): + self.worker_state = [False] * self.tp_world_size + ret = self.held_sample + else: + ret = copy.deepcopy(self.held_sample) + self._worker_state_lock.release() + ret.to_device(to_device) + return ret + + @torch.no_grad() + def _sample_and_erase(self) -> Experience: + ret = self.items.get(block=True) + return ret + + def get_length(self) -> int: + ret = self.items.qsize() + return ret \ No newline at end of file diff --git a/applications/Chat/coati/ray/src/detached_trainer_base.py b/applications/Chat/coati/ray/src/detached_trainer_base.py new file mode 100644 index 0000000000000000000000000000000000000000..f1ed1ec71499a9e7f84c80a52e357ef4c2a92201 --- /dev/null +++ b/applications/Chat/coati/ray/src/detached_trainer_base.py @@ -0,0 +1,121 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Union +from tqdm import tqdm +from coati.trainer.callbacks import Callback +from coati.experience_maker import Experience +import ray +import os + +from .detached_replay_buffer import DetachedReplayBuffer +from .utils import is_rank_0 + +class DetachedTrainer(ABC): + ''' + Base class for detached rlhf trainers. + 'detach' means that the experience maker is detached compared to a normal Trainer. + Please set name attribute during init: + >>> trainer = DetachedTrainer.options(..., name = "xxx", ...).remote() + So an ExperienceMakerHolder can reach the detached_replay_buffer by Actor's name. + Args: + detached_strategy (DetachedStrategy): the strategy to use for training + detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training + experience_batch_size (int, defaults to 8): the batch size to use for experience generation + max_epochs (int, defaults to 1): the number of epochs of training process + data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader + callbacks (List[Callback], defaults to []): the callbacks to call during training process + generate_kwargs (dict, optional): the kwargs to use while model generating + ''' + + def __init__(self, + experience_maker_holder_name_list: List[str], + train_batch_size: int = 8, + buffer_limit: int = 0, + buffer_cpu_offload: bool = True, + experience_batch_size: int = 8, + max_epochs: int = 1, + dataloader_pin_memory: bool = True, + callbacks: List[Callback] = [], + **generate_kwargs) -> None: + super().__init__() + self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit, cpu_offload=buffer_cpu_offload) + self.experience_batch_size = experience_batch_size + self.max_epochs = max_epochs + self.dataloader_pin_memory = dataloader_pin_memory + self.callbacks = callbacks + self.generate_kwargs = generate_kwargs + self.target_holder_name_list = experience_maker_holder_name_list + self.target_holder_list = [] + + def update_target_holder_list(self, experience_maker_holder_name_list): + self.target_holder_name_list = experience_maker_holder_name_list + self.target_holder_list = [] + for name in self.target_holder_name_list: + self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"])) + + @abstractmethod + def _update_remote_makers(self): + pass + + @abstractmethod + def training_step(self, experience: Experience) -> Dict[str, Any]: + pass + + def _learn(self): + pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0()) + for _ in pbar: + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print("[trainer] sampling exp") + experience = self._buffer_sample() + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print("[trainer] training step") + metrics = self.training_step(experience) + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print("[trainer] step over") + pbar.set_postfix(metrics) + + def fit(self, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None: + self._on_fit_start() + for episode in range(num_episodes): + self._on_episode_start(episode) + for timestep in tqdm(range(max_timesteps // update_timesteps), + desc=f'Episode [{episode+1}/{num_episodes}]', + disable=not is_rank_0()): + self._learn() + self._update_remote_makers() + self._on_episode_end(episode) + self._on_fit_end() + + @ray.method(concurrency_group="buffer_length") + def buffer_get_length(self): + # called by ExperienceMakerHolder + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print("[trainer] telling length") + return self.detached_replay_buffer.get_length() + + @ray.method(concurrency_group="buffer_append") + def buffer_append(self, experience: Experience): + # called by ExperienceMakerHolder + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + # print(f"[trainer] receiving exp. Current buffer length: {self.detached_replay_buffer.get_length()}") + print(f"[trainer] receiving exp.") + self.detached_replay_buffer.append(experience) + + @ray.method(concurrency_group="buffer_sample") + def _buffer_sample(self): + return self.detached_replay_buffer.sample() + + def _on_fit_start(self) -> None: + for callback in self.callbacks: + callback.on_fit_start() + + def _on_fit_end(self) -> None: + for callback in self.callbacks: + callback.on_fit_end() + + def _on_episode_start(self, episode: int) -> None: + for callback in self.callbacks: + callback.on_episode_start(episode) + + def _on_episode_end(self, episode: int) -> None: + for callback in self.callbacks: + callback.on_episode_end(episode) diff --git a/applications/Chat/coati/ray/src/detached_trainer_ppo.py b/applications/Chat/coati/ray/src/detached_trainer_ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..838e82d07f4addb8924e1f808c2a47fed363c1f8 --- /dev/null +++ b/applications/Chat/coati/ray/src/detached_trainer_ppo.py @@ -0,0 +1,192 @@ +from typing import Any, Callable, Dict, List, Optional +import torch +from torch.optim import Adam + +from coati.experience_maker import Experience, NaiveExperienceMaker +from coati.models.base import Actor, Critic +from coati.models.generation_utils import update_model_kwargs_fn +from coati.models.loss import PolicyLoss, ValueLoss +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy +from coati.trainer.callbacks import Callback + +from colossalai.nn.optimizer import HybridAdam + +import ray + + +from .utils import is_rank_0, get_cuda_actor_critic_from_args, get_strategy_from_args, set_dist_env +from .detached_trainer_base import DetachedTrainer + + +@ray.remote(concurrency_groups={"buffer_length": 1, "buffer_append":1, "buffer_sample":1,"model_io": 1, "compute": 1}) +class DetachedPPOTrainer(DetachedTrainer): + ''' + Detached Trainer for PPO algorithm + Args: + strategy (Strategy): the strategy to use for training + model (str) : for actor / critic init + pretrained (str) : for actor / critic init + lora_rank (int) : for actor / critic init + train_batch_size (int, defaults to 8): the batch size to use for training + train_batch_size (int, defaults to 8): the batch size to use for training + buffer_limit (int, defaults to 0): the max_size limitation of replay buffer + buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu + eps_clip (float, defaults to 0.2): the clip coefficient of policy loss + value_clip (float, defaults to 0.4): the clip coefficient of value loss + experience_batch_size (int, defaults to 8): the batch size to use for experience generation + max_epochs (int, defaults to 1): the number of epochs of training process + dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader + callbacks (List[Callback], defaults to []): the callbacks to call during training process + generate_kwargs (dict, optional): the kwargs to use while model generating + ''' + + def __init__(self, + experience_maker_holder_name_list: List[str], + strategy: str, + model: str, + env_info: Dict[str, str] = None, + pretrained: str = None, + lora_rank: int = 0, + train_batch_size: int = 8, + buffer_limit: int = 0, + buffer_cpu_offload: bool = True, + eps_clip: float = 0.2, + value_clip: float = 0.4, + experience_batch_size: int = 8, + max_epochs: int = 10, + dataloader_pin_memory: bool = True, + callbacks: List[Callback] = [], + **generate_kwargs) -> None: + # set environment variables + if env_info: + set_dist_env(env_info=env_info) + # configure strategy + self.strategy = get_strategy_from_args(strategy) + # configure models, loss and optimizers + with self.strategy.model_init_context(): + self.actor, self.critic = get_cuda_actor_critic_from_args(model, pretrained, lora_rank) + + if strategy != 'colossalai_gemini': + self.actor.to(torch.float16).to(torch.cuda.current_device()) + self.critic.to(torch.float16).to(torch.cuda.current_device()) + + if strategy.startswith('colossalai'): + self.actor_optim = HybridAdam(self.actor.parameters(), lr=5e-6) + self.critic_optim = HybridAdam(self.critic.parameters(), lr=5e-6) + else: + self.actor_optim = Adam(self.actor.parameters(), lr=5e-6) + self.critic_optim = Adam(self.critic.parameters(), lr=5e-6) + + (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \ + self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim)) + generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor) + + self.actor_loss_fn = PolicyLoss(eps_clip) + self.critic_loss_fn = ValueLoss(value_clip) + + super().__init__(experience_maker_holder_name_list, + train_batch_size=train_batch_size, + buffer_limit=buffer_limit, + buffer_cpu_offload=buffer_cpu_offload, + experience_batch_size=experience_batch_size, + max_epochs=max_epochs, + dataloader_pin_memory=dataloader_pin_memory, + callbacks=callbacks, + **generate_kwargs) + + @ray.method(concurrency_group="model_io") + def _update_remote_makers(self): + # TODO: balance duties + if is_rank_0(): + self.update_target_holder_list(self.target_holder_name_list) + for target_holder in self.target_holder_list: + # TODO: reduce malloc + with torch.no_grad(): + ray.get(target_holder.update_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic())) + + @ray.method(concurrency_group="model_io") + def initialize_remote_makers(self): + # TODO: balance duties + if is_rank_0(): + self.update_target_holder_list(self.target_holder_name_list) + for target_holder in self.target_holder_list: + # TODO: reduce malloc + with torch.no_grad(): + ray.get(target_holder.initialize_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic())) + + @ray.method(concurrency_group="compute") + def training_step(self, experience: Experience) -> Dict[str, float]: + self.actor.train() + self.critic.train() + + experience.to_device(torch.cuda.current_device()) + num_actions = experience.action_mask.size(1) + action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask) + actor_loss = self.actor_loss_fn(action_log_probs, + experience.action_log_probs, + experience.advantages, + action_mask=experience.action_mask) + self.strategy.backward(actor_loss, self.actor, self.actor_optim) + self.strategy.optimizer_step(self.actor_optim) + self.actor_optim.zero_grad() + + values = self.critic(experience.sequences, + action_mask=experience.action_mask, + attention_mask=experience.attention_mask) + critic_loss = self.critic_loss_fn(values, + experience.values, + experience.reward, + action_mask=experience.action_mask) + + self.strategy.backward(critic_loss, self.critic, self.critic_optim) + self.strategy.optimizer_step(self.critic_optim) + self.critic_optim.zero_grad() + return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()} + + def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None: + self.strategy.save_model(self.actor, path, only_rank0) + + def strategy_save_critic(self, path: str, only_rank0: bool = False) -> None: + self.strategy.save_model(self.critic, path, only_rank0) + + def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None: + self.strategy.save_optimizer(self.actor_optim, path, only_rank0) + + def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None: + self.strategy.save_optimizer(self.critic_optim, path, only_rank0) + + def _get_unwrapped_actor(self): + if False: + pass + elif isinstance(self.strategy, ColossalAIStrategy): + ret = Actor(self.strategy._unwrap_model(self.actor)) + return ret + elif isinstance(self.strategy, DDPStrategy): + return Actor(self.strategy._unwrap_actor(self.actor)) + elif isinstance(self.strategy, NaiveStrategy): + return self.actor + + def _get_unwrapped_critic(self): + if False: + pass + elif isinstance(self.strategy, ColossalAIStrategy): + ret = self.strategy._unwrap_model(self.critic) + return ret + elif isinstance(self.strategy, DDPStrategy): + return self.critic.module + elif isinstance(self.strategy, NaiveStrategy): + return self.critic + + +def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None: + origin_model = strategy._unwrap_actor(actor) + new_kwargs = {**generate_kwargs} + # use huggingface models method directly + if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'): + new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation + + if 'update_model_kwargs_fn' not in generate_kwargs: + new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn + + return new_kwargs + \ No newline at end of file diff --git a/applications/Chat/coati/ray/src/experience_maker_holder.py b/applications/Chat/coati/ray/src/experience_maker_holder.py new file mode 100644 index 0000000000000000000000000000000000000000..94e4a3d537a57733ad8dbf510e940ef6c9129f6d --- /dev/null +++ b/applications/Chat/coati/ray/src/experience_maker_holder.py @@ -0,0 +1,172 @@ +import torch +from typing import Any, Callable, Dict, List, Optional, Union +import ray +from ray.exceptions import GetTimeoutError +from torch import Tensor +import torch.nn as nn +from coati.models.base import Actor, Critic, RewardModel +from coati.trainer.strategies.sampler import DistributedSampler +from coati.trainer.strategies import Strategy +from coati.experience_maker import NaiveExperienceMaker, Experience, ExperienceMaker + +from copy import deepcopy +from threading import Lock +import time +import os + + +from .utils import is_rank_0, get_strategy_from_args, set_dist_env + + +@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1}) +class ExperienceMakerHolder: + ''' + Args: + detached_trainer_name_list: str list to get ray actor handleskkk + strategy: + experience_batch_size: batch size of generated experience + kl_coef: the coefficient of kl divergence loss + ''' + + def __init__(self, + detached_trainer_name_list: List[str], + strategy: str, + env_info: Dict[str, str] = None, + experience_batch_size: int = 8, + kl_coef: float = 0.1, + **generate_kwargs): + # set environment variables + if env_info: + set_dist_env(env_info=env_info) + self.target_trainer_list = [] + for name in detached_trainer_name_list: + self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"])) + self.strategy_str = strategy + self.strategy = get_strategy_from_args(strategy) + self.experience_batch_size = experience_batch_size + self.kl_coef = kl_coef + self.generate_kwargs = generate_kwargs + # Need a trainer to give an actor and a critic via initialize_experience_maker(...) + actor, critic, reward_model, initial_model = None, None, None, None + self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef) + self._model_visit_lock = Lock() + self.fully_initialized = False + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print('[maker] Waiting for INIT') + + def _get_ready(self): + while not self.fully_initialized: + time.sleep(1.0) + + def update_target_trainer_list(self, detached_trainer_name_list): + self.target_trainer_list = [] + for name in detached_trainer_name_list: + self.target_trainer_list.append(ray.get_actor(name)) + + # copy from ../trainer/base.py + @ray.method(concurrency_group="compute") + def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience: + self._get_ready() + if isinstance(inputs, Tensor): + return self.experience_maker.make_experience(inputs, **self.generate_kwargs) + elif isinstance(inputs, dict): + return self.experience_maker.make_experience(**inputs, **self.generate_kwargs) + else: + raise ValueError(f'Unsupported input type "{type(inputs)}"') + + @ray.method(concurrency_group="experience_io") + def _send_experience(self, experience): + ''' + ignore it + + # choose a trainer that has the least experience batch in its detached_replay_buffer + chosen_trainer = None + min_length = None + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print("[maker] choosing target trainer") + while chosen_trainer is None: + for target_trainer in self.target_trainer_list: + try: + temp_length = ray.get(target_trainer.buffer_get_length.remote(), timeout=0.1) + if min_length is None: + min_length = temp_length + chosen_trainer = target_trainer + else: + if temp_length < min_length: + min_length = temp_length + chosen_trainer = target_trainer + except GetTimeoutError: + pass + + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print(f"[maker] sending exp to {chosen_trainer}") + chosen_trainer.buffer_append.remote(experience) + ''' + # + if not hasattr(self, "_target_idx"): + self._target_idx = 0 + chosen_trainer = self.target_trainer_list[self._target_idx] + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print(f"[maker] sending exp to {chosen_trainer}") + chosen_trainer.buffer_append.remote(experience) + self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list) + + def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None, times=5000 * 50000): + self._get_ready() + sampler = self.strategy.setup_sampler(dataset) + for _ in range(times): + rand_prompts = sampler.sample(self.experience_batch_size) + if tokenizer is not None: + inputs = tokenizer(rand_prompts) + else: + inputs = rand_prompts + self._model_visit_lock.acquire() + experience = self._make_experience(inputs=inputs) + self._model_visit_lock.release() + self._send_experience(experience=experience) + + @ray.method(concurrency_group="model_io") + def initialize_experience_maker(self, init_actor: Actor, init_critic: Critic): + ''' + called by trainer. Only once. + ''' + # TODO: reduce malloc + if self.fully_initialized: + return + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print('[maker] INIT') + with torch.no_grad(): + with self.strategy.model_init_context(): + actor = init_actor + critic = init_critic + initial_model = deepcopy(actor) + reward_model = RewardModel(deepcopy(critic.model), + deepcopy(critic.value_head)).to(torch.cuda.current_device()) + if self.strategy_str != 'colossalai_gemini': + actor.to(torch.float16).to(torch.cuda.current_device()) + critic.to(torch.float16).to(torch.cuda.current_device()) + initial_model.to(torch.float16).to(torch.cuda.current_device()) + reward_model.to(torch.float16).to(torch.cuda.current_device()) + + self.experience_maker.actor = self.strategy.prepare(actor) + self.experience_maker.critic = self.strategy.prepare(critic) + self.experience_maker.initial_model = self.strategy.prepare(initial_model) + self.experience_maker.reward_model = self.strategy.prepare(reward_model) + self.fully_initialized = True + + @ray.method(concurrency_group="model_io") + def update_experience_maker(self, new_actor: Actor, new_critic: Critic): + ''' + called by trainer + ''' + # TODO: reduce malloc + self._model_visit_lock.acquire() + with torch.no_grad(): + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print("[maker] UPDATE ") + if self.strategy_str != 'colossalai_gemini': + new_actor.to(torch.float16).to(torch.cuda.current_device()) + new_critic.to(torch.float16).to(torch.cuda.current_device()) + self.experience_maker.actor = self.strategy.prepare(new_actor) + self.experience_maker.critic = self.strategy.prepare(new_critic) + self._model_visit_lock.release() diff --git a/applications/Chat/coati/ray/src/pipeline_strategy.py b/applications/Chat/coati/ray/src/pipeline_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..1780839c62ee3477ce84cd412ccf7ac60ab57357 --- /dev/null +++ b/applications/Chat/coati/ray/src/pipeline_strategy.py @@ -0,0 +1,105 @@ +# WIP + + +from coati.trainer.strategies import Strategy +from coati.trainer.strategies import NaiveStrategy +from coati.models.base import Actor, RewardModel, Critic + +import numpy as np +import torch +from torch._C._distributed_rpc import _is_current_rpc_agent_set + +import colossalai +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass +from colossalai.pipeline.middleware.adaptor import get_fx_topology + + +import os +from functools import partial +import random + +rpc_is_initialized = _is_current_rpc_agent_set + +class PipelineModel(torch.nn.Module): + ''' + Actor has 2 kinds of jobs: forward and generate. + better to just pipelinize the inner model + ''' + def __init__(self, + model: torch.nn.Module, + stage_num: int, + num_microbatches: int, + data_kwargs = None, + ): + super().__init__() + # create partition module + def create_partition_module(pp_rank:int, stage_num: int, model, data_kwargs): + model.eval() + tracer = ColoTracer() + meta_args = {k: v.to('meta') for k, v in data_kwargs.items()} + graph = tracer.trace(root=model, meta_args=meta_args) + gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) + annotated_model = balanced_split_pass(gm, stage_num) + top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True) + topo = get_fx_topology(top_module) + for submodule in split_submodules: + if isinstance(submodule, torch.fx.GraphModule): + setattr(submodule, '_topo', topo) + return split_submodules[pp_rank + 1] + + def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int): + partition = create_partition_module(pp_rank, stage_num, model, data_kwargs) + return partition + self.inference_engine = OneFOneBPipelineEngine( + partition_fn=partial(partition, model, data_kwargs), + stage_num=stage_num, + num_microbatches=num_microbatches, + device='cuda', + ) + + def forward(self, + **model_inputs): + return self.inference_engine.forward_backward(**model_inputs, forward_only=True) + + + +class PPStrategy(NaiveStrategy): + """ + Strategy for Pipeline inference (inference only!) + + master node only + """ + def __init__( + self, + seed: int = 42 + ): + self.seed = seed + super().__init__() + + + def setup_distributed(self) -> None: + colossalai.launch_from_torch({}, seed=self.seed) + ppg.set_global_info(rank = int(os.environ['RANK']), + world_size=int(os.environ['WORLD_SIZE']), + dp_degree=1, + tp_degree=1, + num_worker_threads=128, + device="cuda") + + def model_init_context(self): + return super().model_init_context() + + def setup_model(self, model: torch.nn.Module) -> torch.nn.Module: + if isinstance(model, Actor) or \ + isinstance(model, RewardModel) or \ + isinstance(model, Critic): + model.model = PipelineModel(model.model) + + def set_seed(self, seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + diff --git a/applications/Chat/coati/ray/src/utils.py b/applications/Chat/coati/ray/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c750879b6d187b90bef14c15ddf50cd9ca33f2a5 --- /dev/null +++ b/applications/Chat/coati/ray/src/utils.py @@ -0,0 +1,48 @@ +import torch.distributed as dist +from typing import Any, Callable, Dict, List, Optional +from coati.models.bloom import BLOOMActor, BLOOMCritic +from coati.models.gpt import GPTActor, GPTCritic +from coati.models.opt import OPTActor, OPTCritic +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +import torch +import os + +def is_rank_0() -> bool: + return not dist.is_initialized() or dist.get_rank() == 0 + + +def get_cuda_actor_critic_from_args(model: str, pretrained: str = None, lora_rank=0): + if model == 'gpt2': + actor = GPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device()) + critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device()) + elif model == 'bloom': + actor = BLOOMActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device()) + critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device()) + elif model == 'opt': + actor = OPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device()) + critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device()) + else: + raise ValueError(f'Unsupported model "{model}"') + return actor, critic + + +def get_strategy_from_args(strategy: str): + if strategy == 'naive': + strategy_ = NaiveStrategy() + elif strategy == 'ddp': + strategy_ = DDPStrategy() + elif strategy == 'colossalai_gemini': + strategy_ = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) + elif strategy == 'colossalai_zero2': + strategy_ = ColossalAIStrategy(stage=2, placement_policy='cuda') + else: + raise ValueError(f'Unsupported strategy "{strategy}"') + return strategy_ + + +def set_dist_env(env_info: Dict[str, str]): + os.environ["RANK"] = env_info['rank'] + os.environ["LOCAL_RANK"] = env_info['local_rank'] + os.environ["WORLD_SIZE"] = env_info['world_size'] + os.environ['MASTER_PORT'] = env_info['master_port'] + os.environ['MASTER_ADDR'] = env_info['master_addr'] diff --git a/applications/Chat/coati/replay_buffer/__init__.py b/applications/Chat/coati/replay_buffer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ebf60382913ead7247197a6ae7b021ceb7e5d71 --- /dev/null +++ b/applications/Chat/coati/replay_buffer/__init__.py @@ -0,0 +1,4 @@ +from .base import ReplayBuffer +from .naive import NaiveReplayBuffer + +__all__ = ['ReplayBuffer', 'NaiveReplayBuffer'] diff --git a/applications/Chat/coati/replay_buffer/base.py b/applications/Chat/coati/replay_buffer/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4c3812461a10120c358d4ddbccdffde4eff7fdfa --- /dev/null +++ b/applications/Chat/coati/replay_buffer/base.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod +from typing import Any + +from coati.experience_maker.base import Experience + + +class ReplayBuffer(ABC): + """Replay buffer base class. It stores experience. + + Args: + sample_batch_size (int): Batch size when sampling. + limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. + """ + + def __init__(self, sample_batch_size: int, limit: int = 0) -> None: + super().__init__() + self.sample_batch_size = sample_batch_size + # limit <= 0 means unlimited + self.limit = limit + + @abstractmethod + def append(self, experience: Experience) -> None: + pass + + @abstractmethod + def clear(self) -> None: + pass + + @abstractmethod + def sample(self) -> Experience: + pass + + @abstractmethod + def __len__(self) -> int: + pass + + @abstractmethod + def __getitem__(self, idx: int) -> Any: + pass + + @abstractmethod + def collate_fn(self, batch: Any) -> Experience: + pass diff --git a/applications/Chat/coati/replay_buffer/naive.py b/applications/Chat/coati/replay_buffer/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..938f500643c96c7370314bf6e60d3e38d765bc12 --- /dev/null +++ b/applications/Chat/coati/replay_buffer/naive.py @@ -0,0 +1,57 @@ +import random +from typing import List + +import torch +from coati.experience_maker.base import Experience + +from .base import ReplayBuffer +from .utils import BufferItem, make_experience_batch, split_experience_batch + + +class NaiveReplayBuffer(ReplayBuffer): + """Naive replay buffer class. It stores experience. + + Args: + sample_batch_size (int): Batch size when sampling. + limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. + cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True. + """ + + def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None: + super().__init__(sample_batch_size, limit) + self.cpu_offload = cpu_offload + self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}') + # TODO(ver217): add prefetch + self.items: List[BufferItem] = [] + + @torch.no_grad() + def append(self, experience: Experience) -> None: + if self.cpu_offload: + experience.to_device(torch.device('cpu')) + items = split_experience_batch(experience) + self.items.extend(items) + if self.limit > 0: + samples_to_remove = len(self.items) - self.limit + if samples_to_remove > 0: + self.items = self.items[samples_to_remove:] + + def clear(self) -> None: + self.items.clear() + + @torch.no_grad() + def sample(self) -> Experience: + items = random.sample(self.items, self.sample_batch_size) + experience = make_experience_batch(items) + if self.cpu_offload: + experience.to_device(self.target_device) + return experience + + def __len__(self) -> int: + return len(self.items) + + def __getitem__(self, idx: int) -> BufferItem: + return self.items[idx] + + def collate_fn(self, batch) -> Experience: + experience = make_experience_batch(batch) + return experience diff --git a/applications/Chat/coati/replay_buffer/utils.py b/applications/Chat/coati/replay_buffer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6ad0db2c3b609e0a3cfa4b1126550ebb44dec6a2 --- /dev/null +++ b/applications/Chat/coati/replay_buffer/utils.py @@ -0,0 +1,73 @@ +from dataclasses import dataclass +from typing import List, Optional + +import torch +import torch.nn.functional as F +from coati.experience_maker.base import Experience + + +@dataclass +class BufferItem: + """BufferItem is an item of experience data. + + Shapes of each tensor: + sequences: (S) + action_log_probs: (A) + values: (1) + reward: (1) + advantages: (1) + attention_mask: (S) + action_mask: (A) + + "A" is the number of actions. + """ + sequences: torch.Tensor + action_log_probs: torch.Tensor + values: torch.Tensor + reward: torch.Tensor + advantages: torch.Tensor + attention_mask: Optional[torch.LongTensor] + action_mask: Optional[torch.BoolTensor] + + +def split_experience_batch(experience: Experience) -> List[BufferItem]: + batch_size = experience.sequences.size(0) + batch_kwargs = [{} for _ in range(batch_size)] + keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask') + for key in keys: + value = getattr(experience, key) + if isinstance(value, torch.Tensor): + vals = torch.unbind(value) + else: + # None + vals = [value for _ in range(batch_size)] + assert batch_size == len(vals) + for i, v in enumerate(vals): + batch_kwargs[i][key] = v + items = [BufferItem(**kwargs) for kwargs in batch_kwargs] + return items + + +def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor: + assert side in ('left', 'right') + max_len = max(seq.size(0) for seq in sequences) + padded_sequences = [] + for seq in sequences: + pad_len = max_len - seq.size(0) + padding = (pad_len, 0) if side == 'left' else (0, pad_len) + padded_sequences.append(F.pad(seq, padding)) + return torch.stack(padded_sequences, dim=0) + + +def make_experience_batch(items: List[BufferItem]) -> Experience: + kwargs = {} + to_pad_keys = set(('action_log_probs', 'action_mask')) + keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask') + for key in keys: + vals = [getattr(item, key) for item in items] + if key in to_pad_keys: + batch_data = zero_pad_sequences(vals) + else: + batch_data = torch.stack(vals, dim=0) + kwargs[key] = batch_data + return Experience(**kwargs) diff --git a/applications/Chat/coati/trainer/__init__.py b/applications/Chat/coati/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..525b57bf21d351e6522102a8d39ff5eb1aec5250 --- /dev/null +++ b/applications/Chat/coati/trainer/__init__.py @@ -0,0 +1,6 @@ +from .base import Trainer +from .ppo import PPOTrainer +from .rm import RewardModelTrainer +from .sft import SFTTrainer + +__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer'] diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ac3a878be88430f94af1bf19bcde4aaf1d0ec7f2 --- /dev/null +++ b/applications/Chat/coati/trainer/base.py @@ -0,0 +1,75 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from coati.experience_maker import Experience + +from .callbacks import Callback +from .strategies import Strategy + + +class Trainer(ABC): + """ + Base class for rlhf trainers. + + Args: + strategy (Strategy):the strategy to use for training + max_epochs (int, defaults to 1): the number of epochs of training process + dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader + callbacks (List[Callback], defaults to []): the callbacks to call during training process + generate_kwargs (dict, optional): the kwargs to use while model generating + """ + + def __init__(self, + strategy: Strategy, + max_epochs: int = 1, + dataloader_pin_memory: bool = True, + callbacks: List[Callback] = [], + **generate_kwargs) -> None: + super().__init__() + self.strategy = strategy + self.max_epochs = max_epochs + self.generate_kwargs = generate_kwargs + self.dataloader_pin_memory = dataloader_pin_memory + self.callbacks = callbacks + + # TODO(ver217): maybe simplify these code using context + def _on_fit_start(self) -> None: + for callback in self.callbacks: + callback.on_fit_start() + + def _on_fit_end(self) -> None: + for callback in self.callbacks: + callback.on_fit_end() + + def _on_episode_start(self, episode: int) -> None: + for callback in self.callbacks: + callback.on_episode_start(episode) + + def _on_episode_end(self, episode: int) -> None: + for callback in self.callbacks: + callback.on_episode_end(episode) + + def _on_make_experience_start(self) -> None: + for callback in self.callbacks: + callback.on_make_experience_start() + + def _on_make_experience_end(self, experience: Experience) -> None: + for callback in self.callbacks: + callback.on_make_experience_end(experience) + + def _on_learn_epoch_start(self, epoch: int) -> None: + for callback in self.callbacks: + callback.on_learn_epoch_start(epoch) + + def _on_learn_epoch_end(self, epoch: int) -> None: + for callback in self.callbacks: + callback.on_learn_epoch_end(epoch) + + def _on_learn_batch_start(self) -> None: + for callback in self.callbacks: + callback.on_learn_batch_start() + + def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: + for callback in self.callbacks: + callback.on_learn_batch_end(metrics, experience) diff --git a/applications/Chat/coati/trainer/callbacks/__init__.py b/applications/Chat/coati/trainer/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed0ee6f764002fd085f0ecdf957f2dba4d576c8 --- /dev/null +++ b/applications/Chat/coati/trainer/callbacks/__init__.py @@ -0,0 +1,5 @@ +from .base import Callback +from .performance_evaluator import PerformanceEvaluator +from .save_checkpoint import SaveCheckpoint + +__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint'] diff --git a/applications/Chat/coati/trainer/callbacks/base.py b/applications/Chat/coati/trainer/callbacks/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f5616048855b26ceac836e812d1ecdee8035f025 --- /dev/null +++ b/applications/Chat/coati/trainer/callbacks/base.py @@ -0,0 +1,39 @@ +from abc import ABC + +from coati.experience_maker import Experience + + +class Callback(ABC): + """ + Base callback class. It defines the interface for callbacks. + """ + + def on_fit_start(self) -> None: + pass + + def on_fit_end(self) -> None: + pass + + def on_episode_start(self, episode: int) -> None: + pass + + def on_episode_end(self, episode: int) -> None: + pass + + def on_make_experience_start(self) -> None: + pass + + def on_make_experience_end(self, experience: Experience) -> None: + pass + + def on_learn_epoch_start(self, epoch: int) -> None: + pass + + def on_learn_epoch_end(self, epoch: int) -> None: + pass + + def on_learn_batch_start(self) -> None: + pass + + def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: + pass diff --git a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..925455444597095d67330698cffaca151f8c5e0b --- /dev/null +++ b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py @@ -0,0 +1,183 @@ +from time import time +from typing import Optional + +import torch +import torch.distributed as dist +from coati.experience_maker import Experience + +from .base import Callback + + +def get_world_size() -> int: + if dist.is_initialized(): + return dist.get_world_size() + return 1 + + +def print_rank_0(*args, **kwargs) -> None: + if not dist.is_initialized() or dist.get_rank() == 0: + print(*args, **kwargs) + + +def divide(x: float, y: float) -> float: + if y == 0: + return float('inf') + elif y == float('inf'): + return float('nan') + return x / y + + +@torch.no_grad() +def all_reduce_mean(x: float, world_size: int) -> float: + if world_size == 1: + return x + tensor = torch.tensor([x], device=torch.cuda.current_device()) + dist.all_reduce(tensor) + tensor = tensor / world_size + return tensor.item() + + +class Timer: + + def __init__(self) -> None: + self.start_time: Optional[float] = None + self.duration: float = 0. + + def start(self) -> None: + self.start_time = time() + + def end(self) -> None: + assert self.start_time is not None + self.duration += time() - self.start_time + self.start_time = None + + def reset(self) -> None: + self.duration = 0. + + +class PerformanceEvaluator(Callback): + """ + Callback for valuate the performance of the model. + Args: + actor_num_params: The number of parameters of the actor model. + critic_num_params: The number of parameters of the critic model. + initial_model_num_params: The number of parameters of the initial model. + reward_model_num_params: The number of parameters of the reward model. + enable_grad_checkpoint: Whether to enable gradient checkpointing. + ignore_episodes: The number of episodes to ignore when calculating the performance. + """ + + def __init__(self, + actor_num_params: int, + critic_num_params: int, + initial_model_num_params: int, + reward_model_num_params: int, + enable_grad_checkpoint: bool = False, + ignore_episodes: int = 0) -> None: + super().__init__() + self.world_size = get_world_size() + self.actor_num_params = actor_num_params + self.critic_num_params = critic_num_params + self.initial_model_num_params = initial_model_num_params + self.reward_model_num_params = reward_model_num_params + self.enable_grad_checkpoint = enable_grad_checkpoint + self.ignore_episodes = ignore_episodes + self.disable: bool = False + + self.overall_timer = Timer() + self.make_experience_timer = Timer() + self.learn_timer = Timer() + self.make_experience_num_samples: int = 0 + self.make_experience_flop: int = 0 + self.learn_num_samples: int = 0 + self.learn_flop: int = 0 + + def on_episode_start(self, episode: int) -> None: + self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes + if self.disable: + return + self.overall_timer.start() + + def on_episode_end(self, episode: int) -> None: + if self.disable: + return + self.overall_timer.end() + + def on_make_experience_start(self) -> None: + if self.disable: + return + self.make_experience_timer.start() + + def on_make_experience_end(self, experience: Experience) -> None: + if self.disable: + return + self.make_experience_timer.end() + + batch_size, seq_len = experience.sequences.shape + + self.make_experience_num_samples += batch_size + + # actor generate + num_actions = experience.action_mask.size(1) + input_len = seq_len - num_actions + total_seq_len = (input_len + seq_len - 1) * num_actions / 2 + self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2 + # actor forward + self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2 + # critic forward + self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2 + # initial model forward + self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2 + # reward model forward + self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2 + + def on_learn_batch_start(self) -> None: + if self.disable: + return + self.learn_timer.start() + + def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: + if self.disable: + return + self.learn_timer.end() + + batch_size, seq_len = experience.sequences.shape + + self.learn_num_samples += batch_size + + # actor forward-backward, 3 means forward(1) + backward(2) + self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint)) + # critic forward-backward + self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint)) + + def on_fit_end(self) -> None: + avg_make_experience_duration = all_reduce_mean(self.make_experience_timer.duration, self.world_size) + avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size) + avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size) + + avg_make_experience_throughput = self.make_experience_num_samples * \ + self.world_size / (avg_make_experience_duration + 1e-12) + avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12) + + avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12) + avg_learn_tflops = self.learn_flop / 1e12 / (avg_learn_duration + 1e-12) + + num_effective_samples = min(self.learn_num_samples, self.make_experience_num_samples) * self.world_size + + avg_overall_throughput = num_effective_samples / (avg_overall_duration + 1e-12) + + overall_time_per_sample = divide(1, avg_overall_throughput) + make_experience_time_per_sample = divide(avg_make_experience_duration, num_effective_samples) + learn_time_per_sample = divide(avg_learn_duration, num_effective_samples) + + print_rank_0( + f'Performance summary:\n' + + f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n' + + + f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n' + + f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' + + f'Overall time per sample: {overall_time_per_sample:.2f} s\n' + + f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n' + + + f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%' + ) diff --git a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..d2dcc0dd4c65f0aae86a87b188cc2801fa0e5281 --- /dev/null +++ b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py @@ -0,0 +1,75 @@ +import os + +import torch.distributed as dist +from coati.trainer.strategies import ColossalAIStrategy, Strategy +from coati.trainer.utils import is_rank_0 +from torch import nn +from torch.optim import Optimizer + +from .base import Callback + + +class SaveCheckpoint(Callback): + """ + The callback for saving checkpoint for coati. + + Only support saving actor and critic model. + A typical architecture of the saved checkpoint would be: + - checkpoint + - episode_x + - actor.pt + - actor-optim-rank-0.pt + - actor-optim-rank-1.pt + - critic.pt + - critic-optim-rank-0.pt + - critic-optim-rank-1.pt + - ... + + Args: + path(str): the base path you want to save checkpoint, the checkpoint would be saved at `path/checkpoint` + interval(int): the interval episode of saving checkpoint + strategy(Strategy): the strategy used to train + actor(nn.Module): the actor model + critic(nn.Module): the critic model + actor_optim(Optimizer): the optimizer of actor + critic_optim(Optimizer): the optimizer of critic + + """ + + def __init__(self, + path: str, + interval: int, + strategy: Strategy, + actor: nn.Module = None, + critic: nn.Module = None, + actor_optim: Optimizer = None, + critic_optim: Optimizer = None) -> None: + super().__init__() + self.path = os.path.join(path, 'checkpoint') + self.interval = interval + self.strategy = strategy + self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]} + + def on_episode_end(self, episode: int) -> None: + if (episode + 1) % self.interval != 0: + return + base_path = os.path.join(self.path, f'episode_{episode}') + if not os.path.exists(base_path): + os.makedirs(base_path) + + for model in self.model_dict.keys(): + + # save model + if self.model_dict[model][0] is None: + # saving only optimizer states is meaningless, so it would be skipped + continue + model_path = os.path.join(base_path, f'{model}.pt') + self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True) + + # save optimizer + if self.model_dict[model][1] is None: + continue + only_rank0 = not isinstance(self.strategy, ColossalAIStrategy) + rank = 0 if is_rank_0() else dist.get_rank() + optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt') + self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0) diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5ae48d9c2f05af14235a0106c43e74959e4c11 --- /dev/null +++ b/applications/Chat/coati/trainer/ppo.py @@ -0,0 +1,213 @@ +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +import torch.nn as nn +from coati.experience_maker import Experience, NaiveExperienceMaker +from coati.models.base import Actor, Critic +from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss +from coati.replay_buffer import NaiveReplayBuffer +from torch import Tensor +from torch.optim import Optimizer +from torch.utils.data import DistributedSampler +from tqdm import tqdm +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from colossalai.utils import get_current_device + +from .base import Trainer +from .callbacks import Callback +from .strategies import Strategy +from .utils import is_rank_0, to_device + + +class PPOTrainer(Trainer): + """ + Trainer for PPO algorithm. + + Args: + strategy (Strategy): the strategy to use for training + actor (Actor): the actor model in ppo algorithm + critic (Critic): the critic model in ppo algorithm + reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences + initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor + actor_optim (Optimizer): the optimizer to use for actor model + critic_optim (Optimizer): the optimizer to use for critic model + kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss + train_batch_size (int, defaults to 8): the batch size to use for training + buffer_limit (int, defaults to 0): the max_size limitation of replay buffer + buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu + eps_clip (float, defaults to 0.2): the clip coefficient of policy loss + vf_coef (float, defaults to 1.0): the coefficient of value loss + ptx_coef (float, defaults to 0.9): the coefficient of ptx loss + value_clip (float, defaults to 0.4): the clip coefficient of value loss + max_epochs (int, defaults to 1): the number of epochs of training process + sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer + dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader + offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process + callbacks (List[Callback], defaults to []): the callbacks to call during training process + generate_kwargs (dict, optional): the kwargs to use while model generating + """ + + def __init__(self, + strategy: Strategy, + actor: Actor, + critic: Critic, + reward_model: nn.Module, + initial_model: Actor, + actor_optim: Optimizer, + critic_optim: Optimizer, + kl_coef: float = 0.1, + ptx_coef: float = 0.9, + train_batch_size: int = 8, + buffer_limit: int = 0, + buffer_cpu_offload: bool = True, + eps_clip: float = 0.2, + vf_coef: float = 1.0, + value_clip: float = 0.4, + max_epochs: int = 1, + sample_replay_buffer: bool = False, + dataloader_pin_memory: bool = True, + offload_inference_models: bool = True, + callbacks: List[Callback] = [], + **generate_kwargs) -> None: + experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef) + replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) + generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) + super().__init__(strategy, max_epochs, dataloader_pin_memory, callbacks, **generate_kwargs) + + self.experience_maker = experience_maker + self.replay_buffer = replay_buffer + self.sample_replay_buffer = sample_replay_buffer + self.offload_inference_models = offload_inference_models + + self.actor = actor + self.critic = critic + + self.actor_loss_fn = PolicyLoss(eps_clip) + self.critic_loss_fn = ValueLoss(value_clip) + self.vf_coef = vf_coef + self.ptx_loss_fn = GPTLMLoss() + self.ptx_coef = ptx_coef + self.actor_optim = actor_optim + self.critic_optim = critic_optim + + self.device = get_current_device() + + def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience: + if isinstance(inputs, Tensor): + return self.experience_maker.make_experience(inputs, **self.generate_kwargs) + elif isinstance(inputs, dict): + return self.experience_maker.make_experience(**inputs, **self.generate_kwargs) + else: + raise ValueError(f'Unsupported input type "{type(inputs)}"') + + def _learn(self): + # replay buffer may be empty at first, we should rebuild at each training + if not self.sample_replay_buffer: + dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory) + if self.sample_replay_buffer: + pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0()) + for _ in pbar: + experience = self.replay_buffer.sample() + experience.to_device(self.device) + metrics = self.training_step(experience) + pbar.set_postfix(metrics) + else: + for epoch in range(self.max_epochs): + self._on_learn_epoch_start(epoch) + if isinstance(dataloader.sampler, DistributedSampler): + dataloader.sampler.set_epoch(epoch) + pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0()) + for experience in pbar: + self._on_learn_batch_start() + experience.to_device(self.device) + metrics = self.training_step(experience) + self._on_learn_batch_end(metrics, experience) + pbar.set_postfix(metrics) + self._on_learn_epoch_end(epoch) + + def fit(self, + prompt_dataloader, + pretrain_dataloader, + num_episodes: int = 50000, + max_timesteps: int = 500, + update_timesteps: int = 5000) -> None: + time = 0 + self.pretrain_dataloader = pretrain_dataloader + self.prompt_dataloader = prompt_dataloader + self._on_fit_start() + for episode in range(num_episodes): + self._on_episode_start(episode) + for timestep in tqdm(range(max_timesteps), + desc=f'Episode [{episode+1}/{num_episodes}]', + disable=not is_rank_0()): + time += 1 + prompts = next(iter(self.prompt_dataloader)) + self._on_make_experience_start() + if self.offload_inference_models: + # TODO(ver217): this may be controlled by strategy if they are prepared by strategy + self.experience_maker.initial_model.to(self.device) + self.experience_maker.reward_model.to(self.device) + experience = self._make_experience(prompts) + self._on_make_experience_end(experience) + self.replay_buffer.append(experience) + if time % update_timesteps == 0: + if self.offload_inference_models: + self.experience_maker.initial_model.to('cpu') + self.experience_maker.reward_model.to('cpu') + self._learn() + self.replay_buffer.clear() + self._on_episode_end(episode) + self._on_fit_end() + + def training_step(self, experience: Experience) -> Dict[str, float]: + self.actor.train() + self.critic.train() + # policy loss + num_actions = experience.action_mask.size(1) + action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask) + actor_loss = self.actor_loss_fn(action_log_probs, + experience.action_log_probs, + experience.advantages, + action_mask=experience.action_mask) + + # ptx loss + if self.ptx_coef != 0: + batch = next(iter(self.pretrain_dataloader)) + batch = to_device(batch, self.device) + ptx_log_probs = self.actor.get_base_model()(batch['input_ids'], + attention_mask=batch['attention_mask'])['logits'] + ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels']) + actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef) + + self.strategy.backward(actor_loss, self.actor, self.actor_optim) + self.strategy.optimizer_step(self.actor_optim) + self.actor_optim.zero_grad() + + # value loss + values = self.critic(experience.sequences, + action_mask=experience.action_mask, + attention_mask=experience.attention_mask) + critic_loss = self.critic_loss_fn(values, + experience.values, + experience.reward, + action_mask=experience.action_mask) + critic_loss = critic_loss * self.vf_coef + self.strategy.backward(critic_loss, self.critic, self.critic_optim) + self.strategy.optimizer_step(self.critic_optim) + self.critic_optim.zero_grad() + + return {'reward': experience.reward.mean().item()} + + +def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None: + origin_model = strategy.unwrap_model(actor) + new_kwargs = {**generate_kwargs} + # use huggingface models method directly + if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'): + new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation + + if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'): + new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation + + return new_kwargs diff --git a/applications/Chat/coati/trainer/rm.py b/applications/Chat/coati/trainer/rm.py new file mode 100644 index 0000000000000000000000000000000000000000..cdae5108ab00e3506449d8c837f62b2d594d8d31 --- /dev/null +++ b/applications/Chat/coati/trainer/rm.py @@ -0,0 +1,123 @@ +from datetime import datetime +from typing import List, Optional + +import pandas as pd +import torch +import torch.distributed as dist +from torch.optim import Optimizer, lr_scheduler +from torch.utils.data import DataLoader, Dataset, DistributedSampler +from tqdm import tqdm +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from .base import Trainer +from .callbacks import Callback +from .strategies import Strategy +from .utils import is_rank_0 + + +class RewardModelTrainer(Trainer): + """ + Trainer to use while training reward model. + + Args: + model (torch.nn.Module): the model to train + strategy (Strategy): the strategy to use for training + optim(Optimizer): the optimizer to use for training + loss_fn (callable): the loss function to use for training + train_dataloader (DataLoader): the dataloader to use for training + valid_dataloader (DataLoader): the dataloader to use for validation + eval_dataloader (DataLoader): the dataloader to use for evaluation + batch_size (int, defaults to 1): the batch size while training + max_epochs (int, defaults to 2): the number of epochs to train + callbacks (List[Callback], defaults to []): the callbacks to call during training process + """ + + def __init__( + self, + model, + strategy: Strategy, + optim: Optimizer, + loss_fn, + train_dataloader: DataLoader, + valid_dataloader: DataLoader, + eval_dataloader: DataLoader, + max_epochs: int = 1, + callbacks: List[Callback] = [], + ) -> None: + super().__init__(strategy, max_epochs, callbacks=callbacks) + + self.train_dataloader = train_dataloader + self.valid_dataloader = valid_dataloader + self.eval_dataloader = eval_dataloader + + self.model = model + self.loss_fn = loss_fn + self.optimizer = optim + self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__() // 100) + + def eval_acc(self, dataloader): + dist = 0 + on = 0 + cnt = 0 + self.model.eval() + with torch.no_grad(): + for chosen_ids, c_mask, reject_ids, r_mask in dataloader: + chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) + c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) + reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) + r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) + chosen_reward = self.model(chosen_ids, attention_mask=c_mask) + reject_reward = self.model(reject_ids, attention_mask=r_mask) + for i in range(len(chosen_reward)): + cnt += 1 + if chosen_reward[i] > reject_reward[i]: + on += 1 + dist += (chosen_reward - reject_reward).mean().item() + dist_mean = dist / len(dataloader) + acc = on / cnt + self.model.train() + return dist_mean, acc + + def fit(self): + time = datetime.now() + epoch_bar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0()) + for epoch in range(self.max_epochs): + step_bar = tqdm(range(self.train_dataloader.__len__()), + desc='Train step of epoch %d' % epoch, + disable=not is_rank_0()) + # train + self.model.train() + cnt = 0 + acc = 0 + dist = 0 + for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: + chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) + c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) + reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) + r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) + chosen_reward = self.model(chosen_ids, attention_mask=c_mask) + reject_reward = self.model(reject_ids, attention_mask=r_mask) + loss = self.loss_fn(chosen_reward, reject_reward) + self.strategy.backward(loss, self.model, self.optimizer) + self.strategy.optimizer_step(self.optimizer) + self.optimizer.zero_grad() + cnt += 1 + if cnt == 100: + self.scheduler.step() + dist, acc = self.eval_acc(self.valid_dataloader) + cnt = 0 + if is_rank_0(): + log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], + columns=['step', 'loss', 'dist', 'acc']) + log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False) + step_bar.update() + step_bar.set_postfix({'dist': dist, 'acc': acc}) + + # eval + dist, acc = self.eval_acc(self.eval_dataloader) + if is_rank_0(): + log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc']) + log.to_csv('log.csv', mode='a', header=False, index=False) + epoch_bar.update() + step_bar.set_postfix({'dist': dist, 'acc': acc}) + step_bar.close() diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py new file mode 100644 index 0000000000000000000000000000000000000000..63fde53956ccd387296cdfde6068f93d76d4fd3f --- /dev/null +++ b/applications/Chat/coati/trainer/sft.py @@ -0,0 +1,135 @@ +import math +import time +from typing import List, Optional + +import torch +import torch.distributed as dist +import wandb +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from transformers.trainer import get_scheduler + +from .base import Trainer +from .callbacks import Callback +from .strategies import ColossalAIStrategy, Strategy +from .utils import is_rank_0, to_device + + +class SFTTrainer(Trainer): + """ + Trainer to use while training reward model. + + Args: + model (torch.nn.Module): the model to train + strategy (Strategy): the strategy to use for training + optim(Optimizer): the optimizer to use for training + train_dataloader: the dataloader to use for training + eval_dataloader: the dataloader to use for evaluation + batch_size (int, defaults to 1): the batch size while training + max_epochs (int, defaults to 2): the number of epochs to train + callbacks (List[Callback], defaults to []): the callbacks to call during training process + optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer + """ + + def __init__( + self, + model, + strategy: Strategy, + optim: Optimizer, + train_dataloader: DataLoader, + eval_dataloader: DataLoader = None, + max_epochs: int = 2, + accumulation_steps: int = 8, + callbacks: List[Callback] = [], + ) -> None: + if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3: + raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI") + super().__init__(strategy, max_epochs, callbacks=callbacks) + self.train_dataloader = train_dataloader + self.eval_dataloader = eval_dataloader + self.model = model + self.optimizer = optim + + self.accumulation_steps = accumulation_steps + num_update_steps_per_epoch = len(train_dataloader) // self.accumulation_steps + max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch) + + self.scheduler = get_scheduler("cosine", + self.optimizer, + num_warmup_steps=math.ceil(max_steps * 0.03), + num_training_steps=max_steps) + + def fit(self, logger, use_wandb: bool = False): + if use_wandb: + wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) + wandb.watch(self.model) + total_loss = 0 + # epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0()) + step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.max_epochs), + desc=f'steps', + disable=not is_rank_0()) + for epoch in range(self.max_epochs): + + # process_bar = tqdm(range(len(self.train_dataloader)), desc=f'Train process for{epoch}', disable=not is_rank_0()) + # train + self.model.train() + for batch_id, batch in enumerate(self.train_dataloader): + + batch = to_device(batch, torch.cuda.current_device()) + outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) + + loss = outputs.loss + + if loss >= 2.5 and is_rank_0(): + logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}") + + loss = loss / self.accumulation_steps + + self.strategy.backward(loss, self.model, self.optimizer) + + total_loss += loss.item() + + # gradient accumulation + if (batch_id + 1) % self.accumulation_steps == 0: + self.strategy.optimizer_step(self.optimizer) + self.optimizer.zero_grad() + self.scheduler.step() + if is_rank_0() and use_wandb: + wandb.log({ + "loss": total_loss / self.accumulation_steps, + "lr": self.scheduler.get_last_lr()[0], + "epoch": epoch, + "batch_id": batch_id + }) + total_loss = 0 + step_bar.update() + + # if batch_id % log_interval == 0: + # logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}') + # wandb.log({"loss": loss.item()}) + + # process_bar.update() + + # eval + if self.eval_dataloader is not None: + self.model.eval() + with torch.no_grad(): + loss_sum = 0 + num_seen = 0 + for batch in self.eval_dataloader: + batch = to_device(batch, torch.cuda.current_device()) + outputs = self.model(batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"]) + loss = outputs.loss + + loss_sum += loss.item() + num_seen += batch["input_ids"].size(0) + + loss_mean = loss_sum / num_seen + if dist.get_rank() == 0: + logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}') + + # epoch_bar.update() diff --git a/applications/Chat/coati/trainer/strategies/__init__.py b/applications/Chat/coati/trainer/strategies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f258c9b8a87324d28d03a22af94f485dcb416c24 --- /dev/null +++ b/applications/Chat/coati/trainer/strategies/__init__.py @@ -0,0 +1,6 @@ +from .base import Strategy +from .colossalai import ColossalAIStrategy +from .ddp import DDPStrategy +from .naive import NaiveStrategy + +__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy'] diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py new file mode 100644 index 0000000000000000000000000000000000000000..b1452869179ebffdbbb05a42e64777ba3e20ca83 --- /dev/null +++ b/applications/Chat/coati/trainer/strategies/base.py @@ -0,0 +1,132 @@ +from abc import ABC, abstractmethod +from contextlib import nullcontext +from typing import Any, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from coati.models.base import Actor, get_base_model +from coati.replay_buffer import ReplayBuffer +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from .sampler import DistributedSampler + +ModelOptimPair = Tuple[nn.Module, Optimizer] +ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair] + + +class Strategy(ABC): + """ + Base class for training strategies. + """ + + def __init__(self) -> None: + super().__init__() + self.setup_distributed() + + @abstractmethod + def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None: + pass + + @abstractmethod + def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None: + pass + + @abstractmethod + def setup_distributed(self) -> None: + pass + + @abstractmethod + def setup_model(self, model: nn.Module) -> nn.Module: + pass + + @abstractmethod + def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer: + pass + + @abstractmethod + def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader: + pass + + def model_init_context(self): + return nullcontext() + + def prepare( + self, *models_or_model_optim_pairs: ModelOrModelOptimPair + ) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: + """Prepare models or model-optimizer-pairs based on each strategy. + + Example:: + >>> # when fine-tuning actor and critic + >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model) + >>> # or when training reward model + >>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim)) + >>> # or just inference + >>> actor, critic = strategy.prepare(actor, critic) + + Returns: + Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order. + """ + + def prepare_model(model: nn.Module): + if isinstance(model, Actor): + return Actor(self.setup_model(model.get_base_model())) + return self.setup_model(model) + + rets = [] + for arg in models_or_model_optim_pairs: + if isinstance(arg, tuple): + assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"' + model, optimizer = arg + model = prepare_model(model) + optimizer = self.setup_optimizer(optimizer, get_base_model(model)) + rets.append((model, optimizer)) + elif isinstance(arg, nn.Module): + rets.append(prepare_model(arg)) + else: + raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}') + + if len(rets) == 1: + return rets[0] + return rets + + @staticmethod + def unwrap_model(model: nn.Module) -> nn.Module: + """Get the unwrapped model from a wrapped model. Useful for getting original huggingface model. + For Actor, it will unwrap `actor.model`. + + Args: + model (nn.Module): the model to unwrap + + Returns: + nn.Module: the original model (usually a huggingface model) + """ + return get_base_model(model) + + @abstractmethod + def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None: + pass + + @abstractmethod + def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None: + pass + + @abstractmethod + def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: + pass + + @abstractmethod + def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None: + pass + + def setup_sampler(self, dataset) -> DistributedSampler: + return DistributedSampler(dataset, 1, 0) + + @abstractmethod + def save_pretrained(self, + model: nn.Module, + path: str, + only_rank0: bool = True, + tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + pass diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa302c77eeec2efa5fb7e3e789b172e2f9578a2 --- /dev/null +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -0,0 +1,188 @@ +import warnings +from typing import Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.optim as optim +from coati.models.base import get_base_model +from torch.optim import Optimizer +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +import colossalai +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import CPUAdam, HybridAdam +from colossalai.tensor import ProcessGroup, ShardSpec +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper + +from .ddp import DDPStrategy + +logger = get_dist_logger(__name__) + + +class ColossalAIStrategy(DDPStrategy): + """ + The strategy for training with ColossalAI. + + Args: + stage(int): The stage to use in ZeRO. Choose in (1, 2, 3) + precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16. + seed(int): The seed for the random number generator. + shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3. + This is not compativle with `from_pretrained()`. We temporarily disable this and will support it in the future. + placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda') + If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU, + If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest. + pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3. + force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3. + search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3. + hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3. + min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3. + gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3. + reduce_bugket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2. + overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2. + initial_scale(float): The initial scale for the optimizer. + growth_factor(float): The growth factor for the optimizer. + backoff_factor(float): The backoff factor for the optimizer. + growth_interval(int): The growth interval for the optimizer. + hysteresis(int): The hysteresis for the optimizer. + min_scale(float): The minimum scale for the optimizer. + max_scale(float): The maximum scale for the optimizer. + max_norm(float): The maximum norm for the optimizer. + norm_type(float): The norm type for the optimizer. + + """ + + def __init__( + self, + stage: int = 3, + precision: str = 'fp16', + seed: int = 42, + shard_init: bool = False, # only for stage 3 + placement_policy: str = 'cuda', + pin_memory: bool = True, # only for stage 3 + force_outputs_fp32: bool = False, # only for stage 3 + scatter_after_inference: bool = False, # only for stage 3 + search_range_mb: int = 32, # only for stage 3 + hidden_dim: Optional[int] = None, # only for stage 3 + min_chunk_size_mb: float = 32, # only for stage 3 + gpu_margin_mem_ratio: float = 0.0, # only for stage 3 + reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2 + overlap_communication: bool = True, # only for stage 1&2 + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0) -> None: + super().__init__(seed) + assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"' + assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"' + self.stage = stage + # TODO(ver217): support shard_init when using from_pretrained() + if shard_init: + warnings.warn( + f'Shard init is not supported model.from_pretrained() yet. Please load weights after strategy.prepare()' + ) + if stage == 3 and precision == 'fp32': + warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.') + precision = 'fp16' + self.precision = precision + self.shard_init = shard_init + self.gemini_config = dict(device=get_current_device(), + placement_policy=placement_policy, + pin_memory=pin_memory, + force_outputs_fp32=force_outputs_fp32, + strict_ddp_mode=shard_init, + search_range_mb=search_range_mb, + hidden_dim=hidden_dim, + min_chunk_size_mb=min_chunk_size_mb, + scatter_after_inference=scatter_after_inference) + if stage == 3: + self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio) + else: + self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size, + overlap_communication=overlap_communication, + cpu_offload=(placement_policy == 'cpu')) + self.optim_kwargs = dict(initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type) + + def setup_distributed(self) -> None: + colossalai.launch_from_torch({}, seed=self.seed) + + def model_init_context(self): + if self.stage == 3: + world_size = dist.get_world_size() + shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None + default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None + return ColoInitContext(device=get_current_device(), + dtype=torch.half, + default_pg=shard_pg, + default_dist_spec=default_dist_spec) + return super().model_init_context() + + def setup_model(self, model: nn.Module) -> nn.Module: + + model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config) + + if self.stage != 3 and self.precision == 'fp16': + model = model.half().cuda() + return model + + def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer: + assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}' + return zero_optim_wrapper(model, optimizer, optim_config=self.zero_optim_config, **self.optim_kwargs) + + def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None: + optimizer.backward(loss) + + def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None: + optimizer.step() + + def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None: + if only_rank0 and dist.get_rank() != 0 and self.stage != 3: + return + base_model = get_base_model(model) + if self.stage == 3: + assert isinstance(base_model, ZeroDDP) + # for stage 3, state_dict() method should be called on every rank + state_dict = base_model.state_dict(only_rank_0=only_rank0) + else: + # only_rank0 is false or rank == 0 + state_dict = base_model.state_dict() + if only_rank0 and dist.get_rank() != 0: + return + torch.save(state_dict, path) + + def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: + if only_rank0: + raise RuntimeError( + f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.') + torch.save(optimizer.state_dict(), path) + + def unwrap_model(self, model: nn.Module) -> nn.Module: + base_model: Union[nn.Module, ZeroDDP] = get_base_model(model) + if self.stage == 3: + assert isinstance(base_model, ZeroDDP) + return base_model.module + return base_model + + def save_pretrained(self, + model: nn.Module, + path: str, + only_rank0: bool = True, + tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + if self.stage == 3: + raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now') + super().save_pretrained(model, path, only_rank0, tokenizer) diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..7910b57878f86c5a8ed2781ab570e8302c3c2567 --- /dev/null +++ b/applications/Chat/coati/trainer/strategies/ddp.py @@ -0,0 +1,93 @@ +import os +import random +from typing import Optional + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from coati.replay_buffer import ReplayBuffer +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from .naive import NaiveStrategy +from .sampler import DistributedSampler + + +class DDPStrategy(NaiveStrategy): + """ + Strategy for distributed training using torch.distributed. + """ + + def __init__(self, seed: int = 42) -> None: + self.seed = seed + super().__init__() + + def setup_distributed(self) -> None: + try: + rank = int(os.environ['RANK']) + local_rank = int(os.environ['LOCAL_RANK']) + world_size = int(os.environ['WORLD_SIZE']) + host = os.environ['MASTER_ADDR'] + port = int(os.environ['MASTER_PORT']) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" + ) + dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank) + self.set_seed(self.seed) + torch.cuda.set_device(local_rank) + + def set_seed(self, seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + def setup_model(self, model: nn.Module) -> nn.Module: + device = torch.cuda.current_device() + return DDP(model, device_ids=[device]) + + def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader: + # DDP only mode, replay buffers on each rank are different. + # sampler = DistributedSampler(replay_buffer, + # num_replicas=dist.get_world_size(), + # rank=dist.get_rank(), + # shuffle=True, + # seed=self.seed, + # drop_last=True) + return DataLoader( + replay_buffer, + batch_size=replay_buffer.sample_batch_size, + # sampler=sampler, + shuffle=True, + drop_last=True, + pin_memory=pin_memory, + collate_fn=replay_buffer.collate_fn) + + def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None: + if only_rank0 and dist.get_rank() != 0: + return + super().save_model(model, path, only_rank0) + + def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: + if only_rank0 and dist.get_rank() != 0: + return + super().save_optimizer(optimizer, path, only_rank0) + + def setup_sampler(self, dataset) -> DistributedSampler: + return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank()) + + def unwrap_model(self, model: nn.Module) -> nn.Module: + base_model: DDP = super().unwrap_model(model) + return base_model.module + + def save_pretrained(self, + model: nn.Module, + path: str, + only_rank0: bool = True, + tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + if only_rank0 and dist.get_rank() != 0: + return + super().save_pretrained(model, path, only_rank0, tokenizer) diff --git a/applications/Chat/coati/trainer/strategies/naive.py b/applications/Chat/coati/trainer/strategies/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..4d94026ce9320ef754632680d5b44de6baae4862 --- /dev/null +++ b/applications/Chat/coati/trainer/strategies/naive.py @@ -0,0 +1,70 @@ +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.optim as optim +from coati.models.base import get_base_model +from coati.replay_buffer import ReplayBuffer +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from transformers.modeling_utils import PreTrainedModel +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from .base import Strategy + + +class NaiveStrategy(Strategy): + """ + Strategy for single GPU. No parallelism is used. + """ + + def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None: + loss.backward() + + def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None: + optimizer.step() + + def setup_distributed(self) -> None: + pass + + def setup_model(self, model: nn.Module) -> nn.Module: + return model + + def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer: + return optimizer + + def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader: + return DataLoader(replay_buffer, + batch_size=replay_buffer.sample_batch_size, + shuffle=True, + drop_last=True, + pin_memory=pin_memory, + collate_fn=replay_buffer.collate_fn) + + def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None: + base_model = get_base_model(model) + state_dict = base_model.state_dict() + torch.save(state_dict, path) + + def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None: + base_model = get_base_model(model) + state_dict = torch.load(path, map_location=map_location) + base_model.load_state_dict(state_dict, strict=strict) + + def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: + torch.save(optimizer.state_dict(), path) + + def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None: + state_dict = torch.load(path, map_location=map_location) + optimizer.load_state_dict(state_dict) + + def save_pretrained(self, + model: nn.Module, + path: str, + only_rank0: bool = True, + tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + unwrapped_model = self.unwrap_model(model) + assert isinstance(unwrapped_model, PreTrainedModel) + unwrapped_model.save_pretrained(path) + if tokenizer is not None: + tokenizer.save_pretrained(path) diff --git a/applications/Chat/coati/trainer/strategies/sampler.py b/applications/Chat/coati/trainer/strategies/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..d726fa640fa201b8bdec5c7601cc2895c4357316 --- /dev/null +++ b/applications/Chat/coati/trainer/strategies/sampler.py @@ -0,0 +1,32 @@ +import math + +import numpy as np + + +class DistributedSampler: + + def __init__(self, dataset, num_replicas: int, rank: int) -> None: + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + + if len(self.dataset) % self.num_replicas != 0: + self.num_samples = math.ceil( + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) + + self.total_size = self.num_samples * self.num_replicas + + indices = list(range(len(self.dataset))) + indices = indices[:self.total_size] + assert len(indices) == self.total_size + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + self.indices = indices + + def sample(self, batch_size: int) -> list: + sampled_indices = np.random.choice(self.indices, batch_size, replace=False) + return [self.dataset[idx] for idx in sampled_indices] diff --git a/applications/Chat/coati/trainer/utils.py b/applications/Chat/coati/trainer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9cccb5c9260395859611c55207f7dbaf817ef76f --- /dev/null +++ b/applications/Chat/coati/trainer/utils.py @@ -0,0 +1,19 @@ +from typing import Any + +import torch +import torch.distributed as dist +from torch.utils._pytree import tree_map + + +def is_rank_0() -> bool: + return not dist.is_initialized() or dist.get_rank() == 0 + + +def to_device(x: Any, device: torch.device) -> Any: + + def _to(t: Any): + if isinstance(t, torch.Tensor): + return t.to(device) + return t + + return tree_map(_to, x) diff --git a/applications/Chat/coati/utils/__init__.py b/applications/Chat/coati/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..112b82b9706444013013777d12798b4f6b62e52a --- /dev/null +++ b/applications/Chat/coati/utils/__init__.py @@ -0,0 +1,3 @@ +from .tokenizer_utils import prepare_llama_tokenizer_and_embedding, smart_tokenizer_and_embedding_resize + +__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding'] \ No newline at end of file diff --git a/applications/Chat/coati/utils/tokenizer_utils.py b/applications/Chat/coati/utils/tokenizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e0d96cfc8be2711397d23bcfed725bd0ba10a2bf --- /dev/null +++ b/applications/Chat/coati/utils/tokenizer_utils.py @@ -0,0 +1,73 @@ +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +import transformers + +DEFAULT_PAD_TOKEN = "[PAD]" +DEFAULT_EOS_TOKEN = "" +DEFAULT_BOS_TOKEN = "" +DEFAULT_UNK_TOKEN = "" + + +def prepare_llama_tokenizer_and_embedding( + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, + special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN), +): + """prepare llama tokenizer and embedding. + + """ + + if tokenizer.pad_token is None: + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), + tokenizer=tokenizer, + model=model, + ) + + tokenizer.add_special_tokens({ + "eos_token": DEFAULT_EOS_TOKEN, + "bos_token": DEFAULT_BOS_TOKEN, + "unk_token": DEFAULT_UNK_TOKEN, + }) + + return tokenizer + + +def smart_tokenizer_and_embedding_resize( + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, + special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN), +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + + if tokenizer.pad_token is None: + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg diff --git a/applications/Chat/evaluate/README.md b/applications/Chat/evaluate/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7ace4bfe6d1871bb97e3e522dc35128d692a7cb7 --- /dev/null +++ b/applications/Chat/evaluate/README.md @@ -0,0 +1,182 @@ +# Evaluation + +In this directory, we introduce how you can evaluate your model with GPT-4. + +## Evaluation Pipeline + +The whole evaluation process undergoes the following three steps: +1. Prepare the questions following the internal data structure in the data format section (described below). +2. Generate answers from different models: + * Generate answers using GPT-3.5: [`generate_gpt35_answers.py`](generate_gpt35_answers.py). + * Generate answers using your own models: [`generate_answers.py`](generate_answers.py). +3. Evaluate models using GPT-4: [`evaluate.py`](evaluate.py). + +### Generate Answers +#### Generate Answers Using GPT-3.5 +You can provide your own OpenAI key to generate answers from GPT-3.5 using [`generate_gpt35_answers.py`](./generate_gpt35_answers.py). + +An example script is provided as follows: +```shell +python generate_gpt35_answers.py \ + --dataset "path to the question dataset" \ + --answer_path "path to answer folder" \ + --num_workers 4 \ + --openai_key "your openai key" \ + --max_tokens 512 \ +``` + +#### Generate Answers Using our Own Model +You can also generate answers using your own models. The generation process is divided into two stages: +1. Generate answers using multiple GPUs (optional) with batch processing: [`generate_answers.py`](./generate_answers.py). +2. Merge multiple shards and output a single file: [`merge.py`](./merge.py). + +An example script is given as follows: + +```shell +device_number=number of your devices +model_name="name of your model" +model_path="path to your model" +dataset="path to the question dataset" +answer_path="path to save the model answers" + +torchrun --standalone --nproc_per_node=$device_number generate_answers.py \ + --model 'llama' \ + --strategy ddp \ + --model_path $model_path \ + --model_name $model_name \ + --dataset $dataset \ + --batch_size 8 \ + --max_datasets_size 80 \ + --answer_path $answer_path \ + --max_length 512 + +python merge.py \ + --model_name $model_name \ + --shards $device_number \ + --answer_path $answer_path \ + +for (( i=0; i scores[1]: + worse_count += 1 + worse_file.append(review_jsons[idx]) + elif scores[0] < scores[1]: + better_count += 1 + better_file.append(review_jsons[idx]) + else: + tie_count += 1 + tie_file.append(review_jsons[idx]) + ans1_score += scores[0] + ans2_score += scores[1] + + output_review_file.append(review_jsons[idx]) + + better_file.sort(key=lambda x: x['id']) + worse_file.sort(key=lambda x: x['id']) + tie_file.sort(key=lambda x: x['id']) + invalid_file.sort(key=lambda x: x['id']) + output_review_file.sort(key=lambda x: x['id']) + + name1 = os.path.basename(args.answer_file_list[0]).split("_answers")[0] + name2 = os.path.basename(args.answer_file_list[1]).split("_answers")[0] + prefix = f"{name1}_vs_{name2}" + + jdump(better_file, os.path.join( + args.output_folder, prefix, f"{prefix}_better.json")) + jdump(worse_file, os.path.join( + args.output_folder, prefix, f"{prefix}_worse.json")) + jdump(tie_file, os.path.join( + args.output_folder, prefix, f"{prefix}_tie.json")) + jdump(invalid_file, os.path.join( + args.output_folder, prefix, f"{prefix}_invalid.json")) + jdump(output_review_file, os.path.join( + args.output_folder, prefix, f"{prefix}_review.json")) + + if os.path.exists(os.path.join(args.output_folder, "results.json")): + results = jload(os.path.join(args.output_folder, "results.json")) + else: + results = {} + results[prefix] = {'model': [name1, name2], 'better': better_count, 'worse': worse_count, 'tie': tie_count, 'win_rate': better_count / + (len(reviews)-invalid_count), 'score': [ans1_score/(len(reviews)-invalid_count), ans2_score/(len(reviews)-invalid_count)]} + jdump(results, os.path.join(args.output_folder, "results.json")) + + logger.info(f' Total {invalid_count} invalid score pair(s).') + logger.info(f' Model {name2} has {better_count} better answer(s).') + logger.info(f' Model {name2} has {worse_count} worse answer(s).') + logger.info(f' {tie_count} answer(s) play(s) to a tie.') + logger.info( + f' Win rate of model {name2}: {better_count/(len(reviews)-invalid_count):.2f}') + logger.info( + f' Model {name1} average score: {ans1_score/(len(reviews)-invalid_count):.2f}') + logger.info( + f' Model {name2} average score: {ans2_score/(len(reviews)-invalid_count):.2f}') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Model evaluation.') + parser.add_argument('--answer_file_list', nargs='+', default=[]) + parser.add_argument('--prompt_file') + parser.add_argument('--reviewer_file') + parser.add_argument('--output_folder', type=str, default="./output") + parser.add_argument('--openai_key', type=str, default=None) + parser.add_argument('--model', type=str, default="gpt-4") + parser.add_argument('--num_workers', type=int, default=8) + parser.add_argument('--max_tokens', type=int, default=512, + help='maximum number of tokens produced in the output') + args = parser.parse_args() + + if args.openai_key is not None: + os.environ["OPENAI_API_KEY"] = args.openai_key + openai.api_key = os.getenv("OPENAI_API_KEY") + + evaluate(args) diff --git a/applications/Chat/evaluate/evaluate.sh b/applications/Chat/evaluate/evaluate.sh new file mode 100755 index 0000000000000000000000000000000000000000..c51aa941019e55e38f57a459f88dc6ae85264c0e --- /dev/null +++ b/applications/Chat/evaluate/evaluate.sh @@ -0,0 +1,9 @@ +python evaluate.py \ + --answer_file_list "path to answers of model 1" "path to answers of model 2" \ + --prompt_file "path to prompt file" \ + --reviewer_file "path to reviewer file" \ + --output_folder "path to output folder" \ + --openai_key "your openai key" \ + --model "gpt-4" \ + --num_workers 8 \ + --max_tokens 512 \ diff --git a/applications/Chat/evaluate/generate_answers.py b/applications/Chat/evaluate/generate_answers.py new file mode 100644 index 0000000000000000000000000000000000000000..fbebf5c5e6f6835575ba6d9733604924d9e91deb --- /dev/null +++ b/applications/Chat/evaluate/generate_answers.py @@ -0,0 +1,173 @@ +import argparse +import os +import random +import copy +import math +from tqdm import tqdm + +import torch +import torch.distributed as dist +import transformers + +from coati.models.bloom import BLOOMActor +from coati.models.gpt import GPTActor +from coati.models.opt import OPTActor +from coati.models.roberta import RoBERTaActor +from coati.models.llama import LlamaActor +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from transformers import AutoTokenizer, RobertaTokenizer +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.logging import get_dist_logger + +from utils import jload, jdump, is_rank_0 + + +logger = get_dist_logger() + +PROMPT_DICT = { + "prompt_input": + ("Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"), + "prompt_no_input": ("Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:"), +} + + +def generate(args): + # torch.cuda.set_per_process_memory_fraction(0.4) + if args.strategy == 'naive': + strategy = NaiveStrategy() + elif args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2_cpu': + strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + world_size = dist.get_world_size() + rank = dist.get_rank() + + with strategy.model_init_context(): + if args.model == 'gpt2': + actor = GPTActor(pretrained=args.model_path).to( + torch.cuda.current_device()) + elif args.model == 'bloom': + actor = BLOOMActor(pretrained=args.model_path).to( + torch.cuda.current_device()) + elif args.model == 'opt': + actor = OPTActor(pretrained=args.model_path).to( + torch.cuda.current_device()) + elif args.model == 'roberta': + actor = RoBERTaActor(pretrained=args.model_path).to( + torch.cuda.current_device()) + elif args.model == 'llama': + actor = LlamaActor(pretrained=args.model_path).to( + torch.float16).to(torch.cuda.current_device()) + else: + raise ValueError(f'Unsupported model "{args.model}"') + + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') + elif args.model == 'roberta': + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + elif args.model == 'llama': + tokenizer = AutoTokenizer.from_pretrained(args.model_path, + padding_side="right", + use_fast=False, + ) + tokenizer.eos_token = '<\s>' + else: + raise ValueError(f'Unsupported model "{args.model}"') + + questions = [] + if args.max_datasets_size is not None: + questions = random.sample(jload(args.dataset), args.max_datasets_size) + if is_rank_0(): + logger.info( + f"Limiting dataset to {args.max_datasets_size} examples.") + questions = questions[rank:args.max_datasets_size:world_size] + + answers = copy.deepcopy(questions) + + prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] + sources = [ + prompt_input.format_map(example) if example.get( + "input", "") != "" else prompt_no_input.format_map(example) + for example in questions + ] + + if is_rank_0(): + logger.info("Tokenizing inputs... This may take some time...") + + input_ids_list = [] + + for string in sources: + input_ids = tokenizer.encode(string, return_tensors='pt').squeeze(0) + input_ids_list.append(input_ids) + + bar = tqdm(range(math.ceil(len(input_ids_list)/args.batch_size)), + desc=f'steps', disable=not is_rank_0()) + + actor.eval() + with torch.no_grad(): + for i in range(0, len(input_ids_list), args.batch_size): + batch = input_ids_list[i:i+args.batch_size] + batch = [i.flip(dims=[0]) for i in batch] + batch = torch.nn.utils.rnn.pad_sequence(batch, + batch_first=True, + padding_value=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0).to(torch.cuda.current_device()) + batch = batch.flip(dims=[1]) + attention_mask = batch.ne(tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0) + + outputs = actor.model.generate(batch, attention_mask=attention_mask, + max_length=args.max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1) + + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + for j in range(batch.size(0)): + answers[i + + j]['output'] = outputs[j].split("### Response:")[1].strip() + + bar.update() + + jdump(answers, os.path.join(args.answer_path, + f'{args.model_name}_answers_rank{rank}.json')) + + if is_rank_0(): + logger.info( + f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--strategy', + choices=['naive', 'ddp', 'colossalai_gemini', + 'colossalai_zero2', 'colossalai_zero2_cpu'], + default='naive') + parser.add_argument('--model', default='gpt2', + choices=['gpt2', 'bloom', 'opt', 'roberta', 'llama']) + parser.add_argument('--model_path', type=str, default=None) + parser.add_argument('--model_name', type=str, default='model') + parser.add_argument('--dataset', type=str, default=None) + parser.add_argument('--batch_size', type=int, default=1) + parser.add_argument('--max_datasets_size', type=int, default=None) + parser.add_argument('--answer_path', type=str, default="answer") + parser.add_argument('--max_length', type=int, default=1024) + args = parser.parse_args() + generate(args) diff --git a/applications/Chat/evaluate/generate_answers.sh b/applications/Chat/evaluate/generate_answers.sh new file mode 100755 index 0000000000000000000000000000000000000000..36881f5f4f292885a153e4558eb1240823c14bb5 --- /dev/null +++ b/applications/Chat/evaluate/generate_answers.sh @@ -0,0 +1,25 @@ +device_number=number of your devices +model_name="name of your model" +model_path="path to your model" +dataset="path to the question dataset" +answer_path="path to save the model answers" + +torchrun --standalone --nproc_per_node=$device_number generate_answers.py \ + --model 'llama' \ + --strategy ddp \ + --model_path $model_path \ + --model_name $model_name \ + --dataset $dataset \ + --batch_size 8 \ + --max_datasets_size 80 \ + --answer_path $answer_path \ + --max_length 512 + +python merge.py \ + --model_name $model_name \ + --shards $device_number \ + --answer_path $answer_path \ + +for (( i=0; i bool: + return not dist.is_initialized() or dist.get_rank() == 0 + +def _make_w_io_base(f, mode: str): + if not isinstance(f, io.IOBase): + f_dirname = os.path.dirname(f) + if f_dirname != "": + os.makedirs(f_dirname, exist_ok=True) + f = open(f, mode=mode) + return f + +def _make_r_io_base(f, mode: str): + if not isinstance(f, io.IOBase): + f = open(f, mode=mode) + return f + +def jdump(obj, f, mode="w", indent=4, default=str): + """Dump a str or dictionary to a file in json format. + Args: + obj: An object to be written. + f: A string path to the location on disk. + mode: Mode for opening the file. + indent: Indent for storing json dictionaries. + default: A function to handle non-serializable entries; defaults to `str`. + """ + f = _make_w_io_base(f, mode) + if isinstance(obj, (dict, list)): + json.dump(obj, f, indent=indent, default=default) + elif isinstance(obj, str): + f.write(obj) + else: + raise ValueError(f"Unexpected type: {type(obj)}") + f.close() + +def jload(f, mode="r"): + """Load a .json file into a dictionary.""" + f = _make_r_io_base(f, mode) + jdict = json.load(f) + f.close() + return jdict + +def get_json_list(file_path): + with open(file_path, 'r') as f: + json_list = [] + for line in f: + json_list.append(json.loads(line)) + return json_list diff --git a/applications/Chat/examples/README.md b/applications/Chat/examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..561ace2205ed559846560e9352a13852eddd44bd --- /dev/null +++ b/applications/Chat/examples/README.md @@ -0,0 +1,302 @@ +# Examples + +## Table of Contents + +- [Examples](#examples) + - [Table of Contents](#table-of-contents) + - [Install requirements](#install-requirements) + - [Supervised datasets collection](#supervised-datasets-collection) + - [Stage1 - Supervised instructs tuning](#stage1---supervised-instructs-tuning) + - [Arg List](#arg-list) + - [Stage2 - Training reward model](#stage2---training-reward-model) + - [Features and tricks in RM training](#features-and-tricks-in-rm-training) + - [Experiment result](#experiment-result) + - [Arg List](#arg-list-1) + - [Stage3 - Training model using prompts with RL](#stage3---training-model-using-prompts-with-rl) + - [Arg List](#arg-list-2) + - [Inference example - After Stage3](#inference-example---after-stage3) + - [Attention](#attention) + - [data](#data) + - [Support Model](#support-model) + - [GPT](#gpt) + - [BLOOM](#bloom) + - [OPT](#opt) + - [LLaMA](#llama) + - [Add your own models](#add-your-own-models) + - [Actor model](#actor-model) + - [LM model](#lm-model) + - [Reward model](#reward-model) + - [Critic model](#critic-model) + + +--- +## Install requirements + +```shell +pip install -r requirements.txt +``` + +## Supervised datasets collection + +We collected 104K bilingual dataset of Chinese and English, and you can find the datasets in this repo +[InstructionWild](https://github.com/XueFuzhao/InstructionWild). + +The following pic shows how we collected the data. +

+ +

+ +## Stage1 - Supervised instructs tuning + +Stage1 is supervised instructs fine-tuning, which uses the datasets mentioned earlier to fine-tune the model. + +You can run the `examples/train_sft.sh` to start a supervised instructs fine-tuning. + +You can also use the following cmd to start a supervised instructs fine-tuning with your own settings. +``` +torchrun --standalone --nproc_per_node=4 train_sft.py \ + --pretrain "/path/to/LLaMa-7B/" \ + --model 'llama' \ + --strategy colossalai_zero2 \ + --log_interval 10 \ + --save_path /path/to/Coati-7B \ + --dataset /path/to/data.json \ + --batch_size 4 \ + --accumulation_steps 8 \ + --lr 2e-5 \ + --max_datasets_size 512 \ + --max_epochs 1 \ + --grad_checkpoint +``` +### Arg List +- --strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2' +- --model: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom' +- --pretrain: pretrain model, type=str, default=None +- --max_datasets_size: the max size of dataset, type=int, default=None +- --save_path: path to save the model, type=str, default='output' +- --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False +- --max_epochs: max epochs for training, type=int, default=3 +- --batch_size: batch size while training, type=int, default=4 +- --lora_rank: low-rank adaptation matrices rank, type=int, default=0 +- --log_interval: how many steps to log, type=int, default=100 +- --grad_checkpoint: enable gradient checkpointing, type=bool, default=False + +## Stage2 - Training reward model + +We train a reward model in stage 2, which obtains corresponding scores by manually ranking different outputs for the same prompt and supervises the training of the reward model. + +You can run the `examples/train_rm.sh` to start a reward model training. + +You can also use the following cmd to start training a reward model. +``` +torchrun --standalone --nproc_per_node=4 train_reward_model.py \ + --pretrain "/path/to/LLaMa-7B/" \ + --model 'llama' \ + --strategy colossalai_zero2 \ + --loss_fn 'log_exp'\ + --save_path 'rmstatic.pt' \ +``` +### Features and tricks in RM training +- We support [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)and[rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets. +- We support 2 kinds of loss_function named 'log_sig'(used by OpenAI) and 'log_exp'(used by Anthropic). +- We change the loss to valid_acc and pair_dist to monitor progress during training. +- We add special token to the end of the sequence to get better result. +- We use cosine-reducing lr-scheduler for RM training. +- We set value_head as 1 liner layer and initialize the weight of value_head using N(0,1/(d_model + 1)) distribution. +- We train a Bloom-560m reward model for 1 epoch and find the test acc of the model achieve the performance mentions in [Anthropics paper](https://arxiv.org/abs/2204.05862). + +### Experiment result +Model performance in [Anthropics paper](https://arxiv.org/abs/2204.05862): + +
image + +
Our training & test result of bloom-560m for 1 epoch: + +
image + +
We also train the reward model based on LLaMA-7B, which reaches the ACC of 72.06% after 1 epoch, performing almost the same as Anthropic's best RM. + +### Arg List +- --strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2' +- --model: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom' +- --pretrain: pretrain model, type=str, default=None +- --model_path: the path of rm model(if continue to train), type=str, default=None +- --save_path: path to save the model, type=str, default='output' +- --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False +- --max_epochs: max epochs for training, type=int, default=3 +- --dataset: dataset name, type=str, choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'] +- --subset: subset of the dataset, type=str, default=None +- --batch_size: batch size while training, type=int, default=4 +- --lora_rank: low-rank adaptation matrices rank, type=int, default=0 +- --loss_func: which kind of loss function, choices=['log_sig', 'log_exp'] +- --max_len: max sentence length for generation, type=int, default=512 +- --test: whether is only testing, if it's true, the dataset will be small + +## Stage3 - Training model using prompts with RL + +Stage3 uses reinforcement learning algorithm, which is the most complex part of the training process, as shown below: + +

+ +

+ +You can run the `examples/train_prompts.sh` to start PPO training. +You can also use the cmd following to start PPO training. + +``` +torchrun --standalone --nproc_per_node=4 train_prompts.py \ + --pretrain "/path/to/LLaMa-7B/" \ + --model 'llama' \ + --strategy colossalai_zero2 \ + --prompt_dataset /path/to/your/prompt_dataset \ + --pretrain_dataset /path/to/your/pretrain_dataset \ + --rm_pretrain /your/pretrain/rm/defination \ + --rm_path /your/rm/model/path +``` + +Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/example_data_reformat.py) to reformat [seed_prompts_ch.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_ch.jsonl) or [seed_prompts_en.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_en.jsonl) in InstructionWild. +Pretrain dataset: the pretrain dataset including the instruction and corresponding response, e.g. you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) in stage 1 supervised instructs tuning. + +### Arg List +- --strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2' +- --model: model type of actor, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom' +- --pretrain: pretrain model, type=str, default=None +- --rm_model: reward model type, type=str, choices=['gpt2', 'bloom', 'opt', 'llama'], default=None +- --rm_pretrain: pretrain model for reward model, type=str, default=None +- --rm_path: the path of rm model, type=str, default=None +- --save_path: path to save the model, type=str, default='output' +- --prompt_dataset: path of the prompt dataset, type=str, default=None +- --pretrain_dataset: path of the ptx dataset, type=str, default=None +- --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False +- --num_episodes: num of episodes for training, type=int, default=10 +- --max_epochs: max epochs for training in one episode, type=int, default=5 +- --max_timesteps: max episodes in one batch, type=int, default=10 +- --update_timesteps: timesteps to update, type=int, default=10 +- --train_batch_size: batch size while training, type=int, default=8 +- --ptx_batch_size: batch size to compute ptx loss, type=int, default=1 +- --experience_batch_size: batch size to make experience, type=int, default=8 +- --lora_rank: low-rank adaptation matrices rank, type=int, default=0 +- --kl_coef: kl_coef using for computing reward, type=float, default=0.1 +- --ptx_coef: ptx_coef using for computing policy loss, type=float, default=0.9 + +## Inference example - After Stage3 +We support different inference options, including int8 and int4 quantization. +For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference). + + +## Attention +The examples are demos for the whole training process.You need to change the hyper-parameters to reach great performance. + +#### data +- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) +- [x] [hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) +- [ ] [openai/summarize_from_feedback](https://huggingface.co/datasets/openai/summarize_from_feedback) +- [ ] [openai/webgpt_comparisons](https://huggingface.co/datasets/openai/webgpt_comparisons) +- [ ] [Dahoas/instruct-synthetic-prompt-responses](https://huggingface.co/datasets/Dahoas/instruct-synthetic-prompt-responses) + +## Support Model + +### GPT +- [x] GPT2-S (s) +- [x] GPT2-M (m) +- [x] GPT2-L (l) +- [x] GPT2-XL (xl) +- [x] GPT2-4B (4b) +- [ ] GPT2-6B (6b) + +### BLOOM +- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m) +- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1) +- [x] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b) +- [x] [BLOOM-7b](https://huggingface.co/bigscience/bloom-7b1) +- [ ] [BLOOM-175b](https://huggingface.co/bigscience/bloom) + +### OPT +- [x] [OPT-125M](https://huggingface.co/facebook/opt-125m) +- [x] [OPT-350M](https://huggingface.co/facebook/opt-350m) +- [x] [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b) +- [x] [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b) +- [x] [OPT-6.7B](https://huggingface.co/facebook/opt-6.7b) +- [ ] [OPT-13B](https://huggingface.co/facebook/opt-13b) +- [ ] [OPT-30B](https://huggingface.co/facebook/opt-30b) + +### [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) +- [x] LLaMA-7B +- [x] LLaMA-13B +- [ ] LLaMA-33B +- [ ] LLaMA-65B + +## Add your own models + +If you want to support your own model in Coati, please refer the pull request for RoBERTa support as an example --[[chatgpt] add pre-trained model RoBERTa for RLHF stage 2 & 3](https://github.com/hpcaitech/ColossalAI/pull/3223), and submit a PR to us. + +You should complete the implementation of four model classes, including Reward model, Critic model, LM model, Actor model + +here are some example code for a NewModel named `Coati`. +if it is supported in huggingface [transformers](https://github.com/huggingface/transformers), you can load it by `from_pretrained`, o +r you can build your own model by yourself. + +### Actor model +``` +from ..base import Actor +from transformers.models.coati import CoatiModel + +class CoatiActor(Actor): + + def __init__(self, + pretrained: Optional[str] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = CoatiModel.from_pretrained(pretrained) + else: + model = build_model() # load your own model if it is not support in transformers + + super().__init__(model, lora_rank, lora_train_bias) +``` + +### Reward model +``` +from ..base import RewardModel +from transformers.models.coati import CoatiModel + +class CoatiRM(RewardModel): + + def __init__(self, + pretrained: Optional[str] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = CoatiModel.from_pretrained(pretrained) + else: + model = build_model() # load your own model if it is not support in transformers + + value_head = nn.Linear(model.config.n_embd, 1) + value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1)) + super().__init__(model, value_head, lora_rank, lora_train_bias) +``` + +### Critic model + +``` +from ..base import Critic +from transformers.models.coati import CoatiModel + +class CoatiCritic(Critic): + + def __init__(self, + pretrained: Optional[str] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = CoatiModel.from_pretrained(pretrained) + else: + model = build_model() # load your own model if it is not support in transformers + + value_head = nn.Linear(model.config.n_embd, 1) + value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1)) + super().__init__(model, value_head, lora_rank, lora_train_bias) +``` diff --git a/applications/Chat/examples/community/README.md b/applications/Chat/examples/community/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c9c645032288f5beccd6b45f2d53f6ee5677c4e7 --- /dev/null +++ b/applications/Chat/examples/community/README.md @@ -0,0 +1,24 @@ +# Community Examples +--- +We are thrilled to announce the latest updates to ColossalChat, an open-source solution for cloning ChatGPT with a complete RLHF (Reinforcement Learning with Human Feedback) pipeline. + +As Colossal-AI undergoes major updates, we are actively maintaining ColossalChat to stay aligned with the project's progress. With the introduction of Community-driven example, we aim to create a collaborative platform for developers to contribute exotic features built on top of ColossalChat. + +## Community Example + +Community-driven Examples is an initiative that allows users to contribute their own examples to the ColossalChat package, fostering a sense of community and making it easy for others to access and benefit from shared work. The primary goal with community-driven examples is to have a community-maintained collection of diverse and exotic functionalities built on top of the ColossalChat package, which is powered by the Colossal-AI project and its Coati module (ColossalAI Talking Intelligence). + +For more information about community pipelines, please have a look at this [issue](https://github.com/hpcaitech/ColossalAI/issues/3487). + +## Community Examples + +Community examples consist of both inference and training examples that have been added by the community. Please have a look at the following table to get an overview of all community examples. Click on the Code Example to get a copy-and-paste ready code example that you can try out. If a community doesn't work as expected, please open an issue and ping the author on it. + +| Example | Description | Code Example | Colab | Author | +|:---------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------:| +| Peft | Adding Peft support for SFT and Prompts model training | [Huggingface Peft](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/peft) | - | [YY Lin](https://github.com/yynil) | +| Train prompts on Ray | A Ray based implementation of Train prompts example | [Huggingface Peft](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/ray) | - | [MisterLin1995](https://github.com/MisterLin1995) | +|...|...|...|...|...| + +### How to get involved +To join our community-driven initiative, please visit the [ColossalChat GitHub repository](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples), review the provided information, and explore the codebase. To contribute, create a new issue outlining your proposed feature or enhancement, and our team will review and provide feedback. We look forward to collaborating with you on this exciting project! diff --git a/applications/Chat/examples/community/peft/README.md b/applications/Chat/examples/community/peft/README.md new file mode 100644 index 0000000000000000000000000000000000000000..eabb56fd8294ea89ad7632bd73ffece758e29e56 --- /dev/null +++ b/applications/Chat/examples/community/peft/README.md @@ -0,0 +1,24 @@ +# Add Peft support for SFT and Prompts model training + +The original implementation just adopts the loralib and merges the layers into the final model. The huggingface peft is a better lora model implementation and can be easily training and distributed. + +Since reward model is relative small, I just keep it as original one. I suggest train full model to get the proper reward/critic model. + +# Preliminary installation +Since the current pypi peft package(0.2) has some bugs, please install the peft package using source. +``` +git clone https://github.com/huggingface/peft +cd peft +pip install . +``` + +# Usage +For SFT training, just call train_peft_sft.py + +Its arguments are almost identical to train_sft.py instead adding a new eval_dataset if you have a eval_dataset file. The data file is just a plain datafile, please check the format in the easy_dataset.py. + +For stage-3 rlhf training, call train_peft_prompts.py. +Its arguments are almost idential to train_prompts.py. The only difference is that I use text files to indicate the prompt and pretrained data file. The models are included in easy_models.py. Currently only bloom models are tested, but technically gpt2/opt/llama should be supported. + +# Dataformat +Please refer the formats in test_sft.txt, test_prompts.txt, test_pretrained.txt. diff --git a/applications/Chat/examples/community/peft/easy_dataset.py b/applications/Chat/examples/community/peft/easy_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..24ea4f0a86186c4789062ab148c83c916ec86edc --- /dev/null +++ b/applications/Chat/examples/community/peft/easy_dataset.py @@ -0,0 +1,240 @@ +import copy +import json +from typing import Dict, Sequence + +import torch +from datasets import load_dataset +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +IGNORE_INDEX = -100 + + +def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=max_length, + truncation=True, + ) for text in strings + ] + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict: + """Preprocess the data by tokenizing.""" + examples = [s + t for s, t in zip(sources, targets)] + examples_tokenized, sources_tokenized = [ + _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources) + ] + input_ids = examples_tokenized["input_ids"] + labels = copy.deepcopy(input_ids) + for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): + label[:source_len] = IGNORE_INDEX + return dict(input_ids=input_ids, labels=labels) + + +class EasySupervisedDataset(Dataset): + + def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None: + super(EasySupervisedDataset, self).__init__() + with open(data_file, "r", encoding="UTF-8") as f: + all_lines = f.readlines() + #split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:" + sources, targets = [], [] + for line in all_lines: + if "回答:" in line: + sep_index = line.index("回答:") + sources.append(line[:sep_index + 3]) + targets.append(line[sep_index + 3:] + tokenizer.eos_token) + else: + sources.append(line) + targets.append("" + tokenizer.eos_token) + data_dict = preprocess(sources, targets, tokenizer, max_length) + + self.input_ids = data_dict["input_ids"] + self.labels = data_dict["labels"] + self.data_file = data_file + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict(input_ids=self.input_ids[i], labels=self.labels[i]) + + def __repr__(self): + return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})" + + def __str__(self): + return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})" + + +class EasyPromptsDataset(Dataset): + + def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None: + super(EasyPromptsDataset, self).__init__() + with open(data_file, "r", encoding="UTF-8") as f: + all_lines = f.readlines() + all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines] + self.prompts = [ + tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length', + truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0) + for line in tqdm(all_lines) + ] + self.data_file = data_file + + def __len__(self): + return len(self.prompts) + + def __getitem__(self, idx): + return self.prompts[idx] + + def __repr__(self): + return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})" + + def __str__(self): + return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})" + + +class EasyRewardDataset(Dataset): + + def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None: + super(EasyRewardDataset, self).__init__() + self.chosen = [] + self.reject = [] + if special_token is None: + self.end_token = tokenizer.eos_token + else: + self.end_token = special_token + print(self.end_token) + #read all lines in the train_file to a list + with open(train_file, "r", encoding="UTF-8") as f: + all_lines = f.readlines() + for line in tqdm(all_lines): + data = json.loads(line) + prompt = "提问:" + data['prompt'] + " 回答:" + + chosen = prompt + data['chosen'] + self.end_token + chosen_token = tokenizer(chosen, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.chosen.append({ + "input_ids": chosen_token['input_ids'], + "attention_mask": chosen_token['attention_mask'] + }) + + reject = prompt + data['rejected'] + self.end_token + reject_token = tokenizer(reject, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.reject.append({ + "input_ids": reject_token['input_ids'], + "attention_mask": reject_token['attention_mask'] + }) + + def __len__(self): + length = len(self.chosen) + return length + + def __getitem__(self, idx): + return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ + "input_ids"], self.reject[idx]["attention_mask"] + + #python representation of the object and the string representation of the object + def __repr__(self): + return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" + + def __str__(self): + return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" + + +''' +Easy SFT just accept a text file which can be read line by line. However the datasets will group texts together to max_length so LLM will learn the texts meaning better. +If individual lines are not related, just set is_group_texts to False. +''' + + +class EasySFTDataset(Dataset): + + def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None: + super().__init__() + #read the data_file line by line + with open(data_file, "r", encoding="UTF-8") as f: + #encode the text data line by line and put raw python list input_ids only to raw_input_ids list + raw_input_ids = [] + for line in f: + encoded_ids = tokenizer.encode(line) + #if the encoded_ids is longer than max_length, then split it into several parts + if len(encoded_ids) > max_length: + for i in range(0, len(encoded_ids), max_length): + raw_input_ids.append(encoded_ids[i:i + max_length]) + else: + raw_input_ids.append(encoded_ids) + + grouped_inpup_ids = [] + current_input_ids = [] + attention_mask = [] + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + if is_group_texts: + for input_ids in raw_input_ids: + if len(current_input_ids) + len(input_ids) > max_length: + #pad the current_input_ids to max_length with tokenizer.pad_token_id + padded_length = max_length - len(current_input_ids) + current_input_ids.extend([tokenizer.pad_token_id] * padded_length) + grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) + attention_mask.append( + torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) + current_input_ids = [] + else: + current_input_ids.extend(input_ids) + if len(current_input_ids) > 0: + padded_length = max_length - len(current_input_ids) + current_input_ids.extend([tokenizer.pad_token_id] * padded_length) + grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) + attention_mask.append( + torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) + else: + #just append the raw_input_ids to max_length + for input_ids in raw_input_ids: + padded_length = max_length - len(input_ids) + input_ids.extend([tokenizer.pad_token_id] * padded_length) + attention_mask.append( + torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) + grouped_inpup_ids.append(torch.tensor(input_ids, dtype=torch.long)) + self.input_ids = grouped_inpup_ids + self.labels = copy.deepcopy(self.input_ids) + self.file_name = data_file + self.attention_mask = attention_mask + + def __len__(self): + return len(self.input_ids) + + #get item from dataset + def __getitem__(self, idx): + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx]) + + #generate the dataset description to be printed by print in python + def __repr__(self): + return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" + + #generate the dataset description to be printed by print in python + def __str__(self): + return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" diff --git a/applications/Chat/examples/community/peft/easy_models.py b/applications/Chat/examples/community/peft/easy_models.py new file mode 100644 index 0000000000000000000000000000000000000000..fe294868159dde227cae9757da41ee71b5778a25 --- /dev/null +++ b/applications/Chat/examples/community/peft/easy_models.py @@ -0,0 +1,96 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from coati.models.generation import generate +from coati.models.utils import log_probs_from_logits, masked_mean +from peft import PeftModel +from torch.nn.modules import Module +from transformers import BloomConfig, BloomForCausalLM + + +class Actor(Module): + """ + Actor model base class. + + Args: + model (nn.Module): Actor Model. + """ + + def __init__(self, model: nn.Module) -> None: + super().__init__() + self.model = model + + @torch.no_grad() + def generate( + self, + input_ids: torch.Tensor, + return_action_mask: bool = True, + **kwargs + ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: + sequences = generate(self.model, input_ids, **kwargs) + attention_mask = None + pad_token_id = kwargs.get('pad_token_id', None) + if pad_token_id is not None: + attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) + if not return_action_mask: + return sequences, attention_mask, None + input_len = input_ids.size(1) + eos_token_id = kwargs.get('eos_token_id', None) + if eos_token_id is None: + action_mask = torch.ones_like(sequences, dtype=torch.bool) + else: + # left padding may be applied, only mask action + action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input + action_mask[:, :input_len] = False + action_mask = action_mask[:, 1:] + return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):] + + def forward(self, + sequences: torch.LongTensor, + num_actions: int, + attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Returns action log probs + """ + output = self.model(sequences, attention_mask=attention_mask) + logits = output['logits'] + log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + return log_probs[:, -num_actions:] + + def get_base_model(self): + return self.model + + +class BLOOMActor(Actor): + """ + BLOOM Actor model. + + Args: + pretrained (str): Pretrained model name or path. + config (BloomConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_path: str = None) -> None: + if pretrained is not None: + model = BloomForCausalLM.from_pretrained(pretrained) + elif config is not None: + model = BloomForCausalLM(config) + else: + model = BloomForCausalLM(BloomConfig()) + if lora_path is not None: + model = PeftModel.from_pretrained(model, lora_path) + if checkpoint: + model.gradient_checkpointing_enable() + super().__init__(model) + + def print_trainable_parameters(self): + self.get_base_model().print_trainable_parameters() diff --git a/applications/Chat/examples/community/peft/train_peft_prompts.py b/applications/Chat/examples/community/peft/train_peft_prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..0e277021e917a7da8ce5c3df18b0165e22b9a0e2 --- /dev/null +++ b/applications/Chat/examples/community/peft/train_peft_prompts.py @@ -0,0 +1,228 @@ +import argparse + +import pandas as pd +import torch +import torch.distributed as dist +from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset +from coati.models.bloom import BLOOMRM, BLOOMCritic +from coati.models.gpt import GPTRM, GPTActor, GPTCritic +from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM +from coati.models.opt import OPTRM, OPTActor, OPTCritic +from coati.trainer import PPOTrainer +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.utils import prepare_llama_tokenizer_and_embedding +from easy_dataset import EasyPromptsDataset, EasySupervisedDataset +from easy_models import BLOOMActor +from peft import PeftModel +from torch.optim import Adam +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer + +from colossalai.nn.optimizer import HybridAdam + + +def main(args): + # configure strategy + if args.strategy == 'naive': + strategy = NaiveStrategy() + elif args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5) + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + if args.rm_path is not None: + state_dict = torch.load(args.rm_path, map_location='cpu') + + # configure model + if args.model == 'bloom': + # initial_model = BLOOMActor(pretrained=args.pretrain) + print('Using peft lora to load Bloom model as inital_model') + initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path) + print('Using peft lora to load Bloom model as initial_model (Done)') + else: + raise ValueError(f'Unsupported actor model "{args.model}"') + + if args.rm_model == None: + rm_model_name = args.model + else: + rm_model_name = args.rm_model + + if rm_model_name == 'gpt2': + reward_model = GPTRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'bloom': + print("load bloom reward model ", args.rm_pretrain) + reward_model = BLOOMRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'opt': + reward_model = OPTRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'llama': + reward_model = LlamaRM(pretrained=args.rm_pretrain) + else: + raise ValueError(f'Unsupported reward model "{rm_model_name}"') + + if args.rm_path is not None: + print('Loading reward model from', args.rm_path) + reward_model.load_state_dict(state_dict) + + if args.strategy != 'colossalai_gemini': + initial_model.to(torch.float16).to(torch.cuda.current_device()) + reward_model.to(torch.float16).to(torch.cuda.current_device()) + + with strategy.model_init_context(): + if args.model == 'bloom': + # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank) + print('Using peft lora to load Bloom model as Actor') + actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path) + print('Using peft lora to load Bloom model as Actor (Done)') + else: + raise ValueError(f'Unsupported actor model "{args.model}"') + + if rm_model_name == 'gpt2': + critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + elif rm_model_name == 'bloom': + print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True) + critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + print("load bloom critic (Done) ") + elif rm_model_name == 'opt': + critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + elif rm_model_name == 'llama': + critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + else: + raise ValueError(f'Unsupported reward model "{rm_model_name}"') + + if args.rm_path is not None: + print('Loading reward model from', args.rm_path) + critic.load_state_dict(state_dict) + del state_dict + + if args.strategy != 'colossalai_gemini': + critic.to(torch.float16).to(torch.cuda.current_device()) + actor.to(torch.float16).to(torch.cuda.current_device()) + + # configure optimizer + if args.strategy.startswith('colossalai'): + actor_optim = HybridAdam(actor.parameters(), lr=1e-7) + critic_optim = HybridAdam(critic.parameters(), lr=1e-7) + else: + actor_optim = Adam(actor.parameters(), lr=1e-7) + critic_optim = Adam(critic.parameters(), lr=1e-7) + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain) + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain) + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain) + elif args.model == 'llama': + tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) + tokenizer.eos_token = '<\s>' + else: + raise ValueError(f'Unsupported model "{args.model}"') + + if args.model == 'llama': + tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor) + else: + tokenizer.pad_token = tokenizer.eos_token + + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + + prompt_dataset = EasyPromptsDataset(args.prompt_path, tokenizer) + if dist.is_initialized() and dist.get_world_size() > 1: + prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) + else: + prompt_sampler = None + prompt_dataloader = DataLoader(prompt_dataset, + shuffle=(prompt_sampler is None), + sampler=prompt_sampler, + batch_size=args.train_batch_size) + + pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer) + if dist.is_initialized() and dist.get_world_size() > 1: + pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) + else: + pretrain_sampler = None + pretrain_dataloader = DataLoader(pretrain_dataset, + shuffle=(pretrain_sampler is None), + sampler=pretrain_sampler, + batch_size=args.ptx_batch_size, + collate_fn=data_collator) + + def tokenize_fn(texts): + # MUST padding to max length to ensure inputs of all ranks have the same length + # Different length may lead to hang when using gemini, as different generation steps + batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()} + + (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) + + # configure trainer + trainer = PPOTrainer( + strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + kl_coef=args.kl_coef, + ptx_coef=args.ptx_coef, + max_epochs=args.max_epochs, + train_batch_size=args.train_batch_size, + experience_batch_size=args.experience_batch_size, + tokenizer=tokenize_fn, + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + trainer.fit(prompt_dataloader=prompt_dataloader, + pretrain_dataloader=pretrain_dataloader, + num_episodes=args.num_episodes, + max_timesteps=args.max_timesteps, + update_timesteps=args.update_timesteps) + + # save model checkpoint after fitting + trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + strategy.save_optimizer(actor_optim, + 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset') + parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset') + parser.add_argument('--strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive', + help='strategy to use') + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--sft_lora_path', type=str, default=None) + parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama']) + parser.add_argument('--rm_path', type=str, default=None) + parser.add_argument('--rm_pretrain', type=str, default=None) + parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--train_batch_size', type=int, default=2) + parser.add_argument('--ptx_batch_size', type=int, default=1) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument('--kl_coef', type=float, default=0.1) + parser.add_argument('--ptx_coef', type=float, default=0.9) + args = parser.parse_args() + main(args) diff --git a/applications/Chat/examples/community/peft/train_peft_sft.py b/applications/Chat/examples/community/peft/train_peft_sft.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd0ebc12a836d6c699c90ce30538a74c65858f6 --- /dev/null +++ b/applications/Chat/examples/community/peft/train_peft_sft.py @@ -0,0 +1,190 @@ +import argparse +import os + +import loralib as lora +import torch +import torch.distributed as dist +from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset +from coati.models.base import RewardModel +from coati.models.bloom import BLOOMLM +from coati.models.gpt import GPTLM +from coati.models.llama import LlamaLM +from coati.models.opt import OPTLM +from coati.trainer import SFTTrainer +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.utils import prepare_llama_tokenizer_and_embedding +from datasets import load_dataset +from easy_dataset import EasyDataset +from peft import LoraConfig, PeftModel, TaskType, get_peft_model +from torch.optim import Adam +from torch.utils.data import DataLoader +from torch.utils.data.dataloader import default_collate +from torch.utils.data.distributed import DistributedSampler +from transformers import AutoModelForCausalLM, AutoTokenizer, BloomTokenizerFast +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor import ColoParameter + + +def train(args): + # configure strategy + if args.strategy == 'naive': + strategy = NaiveStrategy() + elif args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + # configure model + with strategy.model_init_context(): + print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested') + model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device()) + #if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json + if os.path.exists(args.save_path) and os.path.exists(args.save_path+'/adapter_config.json') \ + and os.path.exists(args.save_path+'/adapter_model.bin'): + print("loading from saved peft model ", args.save_path) + model = PeftModel.from_pretrained(model, args.save_path) + else: + #we'll use peft lora library to do the lora + lora_rank = args.lora_rank if args.lora_rank > 0 else 32 + #config lora with rank of lora_rank + lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=lora_rank, + lora_alpha=32, + lora_dropout=0.1) + model = get_peft_model(model, lora_config) + model.print_trainable_parameters() + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + elif args.model == 'llama': + tokenizer = AutoTokenizer.from_pretrained( + args.pretrain, + padding_side="right", + use_fast=False, + ) + tokenizer.eos_token = '<\s>' + else: + raise ValueError(f'Unsupported model "{args.model}"') + tokenizer.pad_token = tokenizer.eos_token + if args.model == 'llama': + tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model) + + if args.strategy == 'colossalai_gemini': + # this is a hack to deal with the resized embedding + # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatiblity + for name, param in model.named_parameters(): + if not isinstance(param, ColoParameter): + sub_module_name = '.'.join(name.split('.')[:-1]) + weight_name = name.split('.')[-1] + sub_module = model.get_submodule(sub_module_name) + setattr(sub_module, weight_name, ColoParameter(param)) + else: + tokenizer.pad_token = tokenizer.eos_token + + # configure optimizer + if args.strategy.startswith('colossalai'): + optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) + else: + optim = Adam(model.parameters(), lr=args.lr) + + logger = get_dist_logger() + logger.set_level('WARNING') + + # configure dataset + law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text) + train_dataset = law_dataset + print(train_dataset) + eval_dataset = None + if args.eval_dataset is not None: + eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text) + data_collator = default_collate + if dist.is_initialized() and dist.get_world_size() > 1: + train_sampler = DistributedSampler(train_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size()) + if eval_dataset is not None: + eval_sampler = DistributedSampler(eval_dataset, + shuffle=False, + seed=42, + drop_last=False, + rank=dist.get_rank(), + num_replicas=dist.get_world_size()) + else: + train_sampler = None + eval_sampler = None + + train_dataloader = DataLoader(train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True) + if eval_dataset is not None: + eval_dataloader = DataLoader(eval_dataset, + shuffle=(eval_sampler is None), + sampler=eval_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True) + else: + eval_dataloader = None + + trainer = SFTTrainer(model=model, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + batch_size=args.batch_size, + max_epochs=args.max_epochs, + accumulation_steps=args.accumulation_steps) + + trainer.fit(logger=logger, log_interval=args.log_interval) + + # save model checkpoint after fitting on only rank0 + trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + strategy.save_optimizer(trainer.optimizer, + 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--dataset', type=str, default=None) + parser.add_argument('--eval_dataset', type=str, default=None) + parser.add_argument('--save_path', type=str, default='output') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--max_epochs', type=int, default=3) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") + parser.add_argument('--lr', type=float, default=5e-6) + parser.add_argument('--accumulation_steps', type=int, default=8) + parser.add_argument('--enable_peft_lora', action='store_true', default=False) + parser.add_argument("--is_short_text", action='store_true', default=False) + args = parser.parse_args() + train(args) diff --git a/applications/Chat/examples/community/ray/README.md b/applications/Chat/examples/community/ray/README.md new file mode 100644 index 0000000000000000000000000000000000000000..64360bd73ddc8d5627594332f85a7f701c26e1b9 --- /dev/null +++ b/applications/Chat/examples/community/ray/README.md @@ -0,0 +1,17 @@ +# ColossalAI on Ray +## Abstract +This is an experimental effort to run ColossalAI Chat training on Ray +## How to use? +### 1. Setup Ray clusters +Please follow the official [Ray cluster setup instructions](https://docs.ray.io/en/latest/cluster/getting-started.html) to setup an cluster with GPU support. Record the cluster's api server endpoint, it should be something similar to http://your.head.node.addrees:8265 +### 2. Clone repo +Clone this project: +```shell +git clone https://github.com/hpcaitech/ColossalAI.git +``` +### 3. Submit the ray job +```shell +python applications/Chat/examples/community/ray/ray_job_script.py http://your.head.node.addrees:8265 +``` +### 4. View your job on the Ray Dashboard +Open your ray cluster dashboard http://your.head.node.addrees:8265 to view your submitted training job. diff --git a/applications/Chat/examples/community/ray/ray_job_script.py b/applications/Chat/examples/community/ray/ray_job_script.py new file mode 100644 index 0000000000000000000000000000000000000000..53f304d379fec54d82d3775552863e73f8dfcbc4 --- /dev/null +++ b/applications/Chat/examples/community/ray/ray_job_script.py @@ -0,0 +1,22 @@ +import sys + +from ray.job_submission import JobSubmissionClient + + +def main(api_server_endpoint="http://127.0.0.1:8265"): + client = JobSubmissionClient(api_server_endpoint) + client.submit_job( + entrypoint= + "python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv", + runtime_env={ + "working_dir": + "applications/Chat", + "pip": [ + "torch==1.13.1", "transformers>=4.20.1", "datasets", "loralib", "colossalai>=0.2.4", "langchain", + "tokenizers", "fastapi", "sse_starlette", "wandb", "sentencepiece", "gpustat" + ] + }) + + +if __name__ == "__main__": + main(sys.argv[1]) diff --git a/applications/Chat/examples/community/ray/train_prompts_on_ray.py b/applications/Chat/examples/community/ray/train_prompts_on_ray.py new file mode 100644 index 0000000000000000000000000000000000000000..289330ad841516a8bbc17ce80cf022d21f30643a --- /dev/null +++ b/applications/Chat/examples/community/ray/train_prompts_on_ray.py @@ -0,0 +1,555 @@ +import argparse +import logging +import os +import socket +from copy import deepcopy +from typing import Type + +import ray +import torch +from coati.experience_maker.base import Experience +from coati.models.base import RewardModel +from coati.models.bloom import BLOOMActor, BLOOMCritic +from coati.models.gpt import GPTActor, GPTCritic +from coati.models.lora import LoRAModule +from coati.models.loss import PolicyLoss, ValueLoss +from coati.models.opt import OPTActor, OPTCritic +from coati.models.utils import compute_reward +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from torch.optim import Adam +from transformers import AutoTokenizer, BloomTokenizerFast +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.nn.optimizer import HybridAdam + + +class ExperienceCompositionRefs: + + def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, action_log_probs_ref: ray.ObjectRef, + base_action_log_probs_ref: ray.ObjectRef, value_ref: ray.ObjectRef, r_ref: ray.ObjectRef) -> None: + self.sequences_attention_mask_action_mask_ref = sequences_attention_mask_action_mask_ref + self.action_log_probs_ref = action_log_probs_ref + self.base_action_log_probs_ref = base_action_log_probs_ref + self.value_ref = value_ref + self.r_ref = r_ref + + +class ExperienceMaker: + + def __init__(self, kl_coef) -> None: + self.kl_coef = kl_coef + + @torch.no_grad() + def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs): + sequences, attention_mask, action_mask = ray.get( + experiment_computation_refs.sequences_attention_mask_action_mask_ref) + action_log_probs = ray.get(experiment_computation_refs.action_log_probs_ref) + base_action_log_probs = ray.get(experiment_computation_refs.base_action_log_probs_ref) + r = ray.get(experiment_computation_refs.r_ref) + reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask) + value = ray.get(experiment_computation_refs.value_ref) + advantage = reward - value + if advantage.ndim == 1: + advantage = advantage.unsqueeze(-1) + experience = Experience(sequences, action_log_probs, value, reward, advantage, attention_mask, action_mask) + return experience + + +class DistributedTorchRayActor: + + def __init__(self, world_size, rank, local_rank, master_addr, master_port): + logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', + level=logging.INFO, + datefmt='%Y-%m-%d %H:%M:%S') + self._model = None + self._world_size = world_size + self._rank = rank + self._local_rank = local_rank + self._master_addr = master_addr if master_addr else self._get_current_node_ip() + self._master_port = master_port if master_port else self._get_free_port() + os.environ["MASTER_ADDR"] = self._master_addr + os.environ["MASTER_PORT"] = str(self._master_port) + os.environ["WORLD_SIZE"] = str(self._world_size) + os.environ["RANK"] = str(self._rank) + os.environ["LOCAL_RANK"] = str(self._local_rank) + + @staticmethod + def _get_current_node_ip(): + return ray._private.services.get_node_ip_address() + + @staticmethod + def _get_free_port(): + with socket.socket() as sock: + sock.bind(('', 0)) + return sock.getsockname()[1] + + def get_master_addr_port(self): + return self._master_addr, self._master_port + + +class BasePPORole(DistributedTorchRayActor): + + def add_experience_maker(self, kl_coef: float = 0.1): + self._experience_maker = ExperienceMaker(kl_coef) + + def make_experience(self, experience_computation_ref: ExperienceCompositionRefs): + return self._experience_maker.make_experience(experience_computation_ref) + + def _init_strategy(self, strategy: str): + # configure strategy + if strategy == 'naive': + self._strategy = NaiveStrategy() + elif strategy == 'ddp': + self._strategy = DDPStrategy() + elif strategy == 'colossalai_gemini': + self._strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) + elif strategy == 'colossalai_zero2': + self._strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + else: + raise ValueError(f'Unsupported strategy "{strategy}"') + + def _init_optimizer(self): + if isinstance(self._strategy, ColossalAIStrategy): + self._optimizer = HybridAdam(self._model.parameters(), lr=5e-6) + else: + self._optimizer = Adam(self._model.parameters(), lr=5e-6) + + def _prepare_model_with_strategy(self, has_optimizer: bool): + if has_optimizer: + self._init_optimizer() + (self._model, self._optimizer) = self._strategy.prepare((self._model, self._optimizer)) + else: + self._model = self._strategy.prepare(self._model) + + def _load_model_from_pretrained(self, model_class: Type[LoRAModule], pretrain: str): + raise NotImplementedError() + + def init_model_from_pretrained(self, + strategy: str, + model_class: Type[LoRAModule], + pretrain: str, + has_optimizer=False): + self._init_strategy(strategy) + self._load_model_from_pretrained(model_class, pretrain) + self._prepare_model_with_strategy(has_optimizer) + + def eval(self): + self._model.eval() + + +class TrainablePPORole(BasePPORole): + + def _load_model_from_pretrained(self, model_class, pretrain): + with self._strategy.model_init_context(): + self._model = model_class(pretrain).to(torch.cuda.current_device()) + + def _train(self): + self._model.train() + + def _training_step(self, experience: Experience): + raise NotImplementedError() + + def learn_on_experiences(self, experience_refs): + experiences = ray.get(experience_refs) + device = torch.cuda.current_device() + self._train() + for exp in experiences: + exp.to_device(device) + self._training_step(exp) + self.eval() + + +@ray.remote(num_gpus=1) +class RayPPOActor(TrainablePPORole): + + def set_loss_function(self, eps_clip: float): + self._actor_loss_fn = PolicyLoss(eps_clip) + + def load_tokenizer_from_pretrained(self, model_type: str, pretrained): + if model_type == 'gpt2': + self._model_tokenizer = GPT2Tokenizer.from_pretrained(pretrained) + self._model_tokenizer.pad_token = self._model_tokenizer.eos_token + elif model_type == 'bloom': + self._model_tokenizer = BloomTokenizerFast.from_pretrained(pretrained) + self._model_tokenizer.pad_token = self._model_tokenizer.eos_token + elif model_type == 'opt': + self._model_tokenizer = AutoTokenizer.from_pretrained(pretrained) + else: + raise ValueError(f'Unsupported model "{model_type}"') + + # Set tokenize function for sequence generation + def _text_input_tokenize_fn(texts): + batch = self._model_tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True) + return {k: v.cuda() for k, v in batch.items()} + + self._sample_tokenize_function = _text_input_tokenize_fn + + def setup_generate_kwargs(self, generate_kwargs: dict): + from coati.trainer.ppo import _set_default_generate_kwargs + self._generate_kwargs = _set_default_generate_kwargs(self._strategy, generate_kwargs, self._model) + self._generate_kwargs['pad_token_id'] = self._model_tokenizer.pad_token_id + self._generate_kwargs['eos_token_id'] = self._model_tokenizer.eos_token_id + + def load_csv_prompt_file_from_url_to_sampler(self, prompt_url): + import pandas as pd + prompts = pd.read_csv(prompt_url)['prompt'] + self._sampler = self._strategy.setup_sampler(prompts) + + def _generate(self, input_ids, **generate_kwargs): + return self._model.generate(input_ids, return_action_mask=True, **generate_kwargs) + + def sample_prompts_and_make_sequence(self, experience_batch_size): + sampled_prompts = self._sampler.sample(experience_batch_size) + input_ids = self._sample_tokenize_function(sampled_prompts) + if isinstance(input_ids, dict): + return self._generate(**input_ids, **self._generate_kwargs) + else: + return self._generate(input_ids, **self._generate_kwargs) + + @torch.no_grad() + def calculate_action_log_probs(self, sequence_attention_action_mask): + sequences, attention_mask, action_mask = sequence_attention_action_mask + return self._model.forward(sequences, action_mask.size(1), attention_mask) + + def _training_step(self, experience): + num_actions = experience.action_mask.size(1) + action_log_probs = self._model(experience.sequences, num_actions, attention_mask=experience.attention_mask) + actor_loss = self._actor_loss_fn(action_log_probs, + experience.action_log_probs, + experience.advantages, + action_mask=experience.action_mask) + self._strategy.backward(actor_loss, self._model, self._optimizer) + self._strategy.optimizer_step(self._optimizer) + self._optimizer.zero_grad() + logging.info("actor_loss: {}".format(actor_loss)) + + def save_checkpoint(self, save_path, should_save_optimizer: bool): + if self._rank == 0: + # save model checkpoint only on rank 0 + self._strategy.save_model(self._model, save_path, only_rank0=True) + # save optimizer checkpoint on all ranks + if should_save_optimizer: + self._strategy.save_optimizer(self._optimizer, + 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + + def generate_answer(self, prompt, max_length=30, num_return_sequences=5): + encoded_input = self._model_tokenizer(prompt, return_tensors='pt') + input_ids = {k: v.cuda() for k, v in encoded_input.items()} + sequence, _ = self._model.generate(**input_ids, + max_length=max_length, + return_action_mask=False, + num_return_sequences=num_return_sequences) + token_list = list(sequence.data[0]) + output = " ".join([self._model_tokenizer.decode(token) for token in token_list]) + return output + + +@ray.remote(num_gpus=1) +class RayPPOCritic(TrainablePPORole): + + def set_loss_function(self, value_clip: float): + self._critic_loss_fn = ValueLoss(value_clip) + + def _training_step(self, experience): + values = self._model(experience.sequences, + action_mask=experience.action_mask, + attention_mask=experience.attention_mask) + critic_loss = self._critic_loss_fn(values, + experience.values, + experience.reward, + action_mask=experience.action_mask) + self._strategy.backward(critic_loss, self._model, self._optimizer) + self._strategy.optimizer_step(self._optimizer) + self._optimizer.zero_grad() + logging.info("critic_loss: {}".format(critic_loss)) + + @torch.no_grad() + def calculate_value(self, sequence_attention_action_mask): + sequences, attention_mask, action_mask = sequence_attention_action_mask + return self._model(sequences, action_mask, attention_mask) + + +@ray.remote(num_gpus=1) +class RayPPORewardModel(BasePPORole): + + def _load_model_from_pretrained(self, model_class, pretrain): + with self._strategy.model_init_context(): + critic = model_class(pretrained=pretrain).to(torch.cuda.current_device()) + self._model = RewardModel(deepcopy(critic.model), + deepcopy(critic.value_head)).to(torch.cuda.current_device()) + + @torch.no_grad() + def calculate_r(self, sequence_attention_action_mask): + sequences, attention_mask, _ = sequence_attention_action_mask + return self._model(sequences, attention_mask) + + +@ray.remote(num_gpus=1) +class RayPPOInitialModel(BasePPORole): + + def _load_model_from_pretrained(self, model_class, pretrain): + with self._strategy.model_init_context(): + self._model = model_class(pretrain).to(torch.cuda.current_device()) + + @torch.no_grad() + def calculate_base_action_log_probs(self, sequence_attention_action_mask): + sequences, attention_mask, action_mask = sequence_attention_action_mask + return self._model(sequences, action_mask.size(1), attention_mask) + + +class PPORayActorGroup: + """ + A group of ray actors + Functions start with 'async' should return list of object refs + """ + + def __init__(self, num_nodes, num_gpus_per_node, ray_actor_type: Type[BasePPORole]) -> None: + self._num_nodes = num_nodes + self._num_gpus_per_node = num_gpus_per_node + self.ray_actor_type = ray_actor_type + self._initiate_actors() + + def _initiate_actors(self): + world_size = self._num_nodes * self._num_gpus_per_node + # Use placement group to lock resources for models of same type + pg = None + if self._num_gpus_per_node > 1: + bundles = [{"GPU": self._num_gpus_per_node, "CPU": self._num_gpus_per_node} for _ in range(self._num_nodes)] + pg = placement_group(bundles, strategy="STRICT_SPREAD") + ray.get(pg.ready()) + if pg: + master_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, placement_group_bundle_index=0)).remote(world_size, 0, 0, None, None) + else: + master_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, 0, 0, None, None) + self._actor_handlers = [master_actor] + + # Create worker actors + if world_size > 1: + master_addr, master_port = ray.get(master_actor.get_master_addr_port.remote()) + for rank in range(1, world_size): + local_rank = rank % self._num_gpus_per_node + if pg: + worker_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node)).remote( + world_size, rank, local_rank, master_addr, master_port) + else: + worker_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, rank, local_rank, + master_addr, master_port) + self._actor_handlers.append(worker_actor) + + def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRAModule], pretrain: str, + has_optimizer: bool): + return [ + actor.init_model_from_pretrained.remote(strategy, model_class, pretrain, has_optimizer) + for actor in self._actor_handlers + ] + + +class TrainableModelRayActorGroup(PPORayActorGroup): + + def async_learn_on_experiences(self, experience_refs): + num_actors = len(self._actor_handlers) + learn_result_refs = [] + for i in range(num_actors): + exp_refs_batch = experience_refs[i::num_actors] + learn_result_refs.append(self._actor_handlers[i].learn_on_experiences.remote(exp_refs_batch)) + return learn_result_refs + + +class PPOActorRayActorGroup(TrainableModelRayActorGroup): + + def __init__(self, num_nodes, num_gpus_per_node) -> None: + super().__init__(num_nodes, num_gpus_per_node, RayPPOActor) + + def async_prepare_for_sequence_generation(self, model: str, pretrain: str, generation_kwargs: dict): + refs = [] + for actor in self._actor_handlers: + refs.append(actor.load_tokenizer_from_pretrained.remote(model, pretrain)) + refs.append(actor.setup_generate_kwargs.remote(generation_kwargs)) + return refs + + def load_csv_prompt_file_from_url_to_sampler(self, csv_url): + ray.get([actor.load_csv_prompt_file_from_url_to_sampler.remote(csv_url) for actor in self._actor_handlers]) + + def async_sample_prompts_and_make_sequence(self, experience_batch_size): + return [actor.sample_prompts_and_make_sequence.remote(experience_batch_size) for actor in self._actor_handlers] + + def async_calculate_action_log_probs(self, sequences_attention_mask_action_mask_refs): + num_actors = len(self._actor_handlers) + action_log_probs_refs = [] + for i in range(len(sequences_attention_mask_action_mask_refs)): + action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_action_log_probs.remote( + sequences_attention_mask_action_mask_refs[i]) + action_log_probs_refs.append(action_log_probs_ref) + return action_log_probs_refs + + def set_loss_function(self, eps_clip: float = 0.2): + ray.get([actor.set_loss_function.remote(eps_clip) for actor in self._actor_handlers]) + + def save_checkpoint(self, save_path, should_save_optimizer): + ray.get([actor.save_checkpoint.remote(save_path, should_save_optimizer) for actor in self._actor_handlers]) + + +class PPOCriticRayActorGroup(TrainableModelRayActorGroup): + + def __init__(self, num_nodes, num_gpus_per_node) -> None: + super().__init__(num_nodes, num_gpus_per_node, RayPPOCritic) + + def async_calculate_value(self, sequences_attention_mask_action_mask_refs): + num_actors = len(self._actor_handlers) + value_refs = [] + for i in range(len(sequences_attention_mask_action_mask_refs)): + value_ref = self._actor_handlers[i % num_actors].calculate_value.remote( + sequences_attention_mask_action_mask_refs[i]) + value_refs.append(value_ref) + return value_refs + + def set_loss_function(self, value_clip: float = 0.4): + ray.get([actor.set_loss_function.remote(value_clip) for actor in self._actor_handlers]) + + +class PPOInitialRayActorGroup(PPORayActorGroup): + + def __init__(self, num_nodes, num_gpus_per_node) -> None: + super().__init__(num_nodes, num_gpus_per_node, RayPPOInitialModel) + + def async_calculate_base_action_log_probs(self, sequences_attention_mask_action_mask_refs): + num_actors = len(self._actor_handlers) + base_action_log_probs_refs = [] + for i in range(len(sequences_attention_mask_action_mask_refs)): + base_action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_base_action_log_probs.remote( + sequences_attention_mask_action_mask_refs[i]) + base_action_log_probs_refs.append(base_action_log_probs_ref) + return base_action_log_probs_refs + + +class PPORewardRayActorGroup(PPORayActorGroup): + + def __init__(self, num_nodes, num_gpus_per_node) -> None: + super().__init__(num_nodes, num_gpus_per_node, RayPPORewardModel) + + def async_calculate_r(self, sequences_attention_mask_action_mask_refs): + num_actors = len(self._actor_handlers) + r_refs = [] + for i in range(len(sequences_attention_mask_action_mask_refs)): + r_ref = self._actor_handlers[i % num_actors].calculate_r.remote( + sequences_attention_mask_action_mask_refs[i]) + r_refs.append(r_ref) + return r_refs + + +def main(args): + logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', + level=logging.INFO, + datefmt='%Y-%m-%d %H:%M:%S') + if args.model == 'gpt2': + actor_model_class, critic_model_class = GPTActor, GPTCritic + elif args.model == 'bloom': + actor_model_class, critic_model_class = BLOOMActor, BLOOMCritic + elif args.model == 'opt': + actor_model_class, critic_model_class = OPTActor, OPTCritic + else: + raise ValueError(f'Unsupported model "{args.model}"') + + logging.info("Start creating actors") + # Initialize 4 models (actor, critic, initial_model and reward_model) + actor_group = PPOActorRayActorGroup(num_nodes=args.num_actor_nodes, num_gpus_per_node=args.num_gpus_per_node) + critic_group = PPOCriticRayActorGroup(num_nodes=args.num_critic_nodes, num_gpus_per_node=args.num_gpus_per_node) + initial_group = PPOInitialRayActorGroup(num_nodes=args.num_initial_nodes, num_gpus_per_node=args.num_gpus_per_node) + reward_group = PPORewardRayActorGroup(num_nodes=args.num_reward_nodes, num_gpus_per_node=args.num_gpus_per_node) + logging.info("Actors created") + + # Prepare model for training + generate_kwargs = {'max_length': 128, 'do_sample': True, 'temperature': 1.0, 'top_k': 50} + ray.get( + actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) + + critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) + + initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) + + reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) + + actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs)) + logging.info("Models prepared for training") + + # Prepare models for training + actor_group.load_csv_prompt_file_from_url_to_sampler(args.prompt_csv_url) + actor_group.set_loss_function() + critic_group.set_loss_function() + # Training parameter + num_episodes = args.num_episodes + max_timesteps = args.max_timesteps + update_timesteps = args.update_timesteps + experience_batch_size = args.experience_batch_size + # Start training + logging.info("Training start") + # Set all models to eval and add experience maker + all_ray_actors = actor_group._actor_handlers + critic_group._actor_handlers + \ + initial_group._actor_handlers + reward_group._actor_handlers + num_ray_actors = len(all_ray_actors) + ray.get([ray_actor.eval.remote() for ray_actor in all_ray_actors]) + ray.get([ray_actor.add_experience_maker.remote() for ray_actor in all_ray_actors]) + # Used as a queue to coordinate experience making + experience_composition_refs = [] + time = 0 + for episode in range(num_episodes): + logging.info("episode {} started".format(episode)) + for _ in range(max_timesteps): + time += 1 + # Experience queueing stage + sequences_attention_mask_action_mask_refs = actor_group.async_sample_prompts_and_make_sequence( + experience_batch_size) + base_action_log_probs_refs = initial_group.async_calculate_base_action_log_probs( + sequences_attention_mask_action_mask_refs) + values_refs = critic_group.async_calculate_value(sequences_attention_mask_action_mask_refs) + r_refs = reward_group.async_calculate_r(sequences_attention_mask_action_mask_refs) + action_log_probs_refs = actor_group.async_calculate_action_log_probs( + sequences_attention_mask_action_mask_refs) + experience_composition_refs.extend([ + ExperienceCompositionRefs(sequences_attention_mask_action_mask_refs[i], action_log_probs_refs[i], + base_action_log_probs_refs[i], values_refs[i], r_refs[i]) + for i in range(len(sequences_attention_mask_action_mask_refs)) + ]) + # Learning stage + if time % update_timesteps == 0: + experience_refs = [] + # calculate experiences + for i, experience_composition_ref in enumerate(experience_composition_refs): + exp_composition_ref = experience_composition_ref + selected_ray_actor = all_ray_actors[i % num_ray_actors] + experience_refs.append(selected_ray_actor.make_experience.remote(exp_composition_ref)) + # backward + ray.get( + actor_group.async_learn_on_experiences(experience_refs) + + critic_group.async_learn_on_experiences(experience_refs)) + # clear refs queue + experience_composition_refs.clear() + logging.info("Training finished") + # Save checkpoint + actor_group.save_checkpoint(args.save_path, args.need_optim_ckpt) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--prompt_csv_url', type=str) + parser.add_argument('--strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) + parser.add_argument('--pretrain', type=str, default='gpt2') + parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--num_actor_nodes', type=int, help='num of nodes to use to host actor model', default=1) + parser.add_argument('--num_critic_nodes', type=int, help='num of nodes to use to host critic model', default=1) + parser.add_argument('--num_initial_nodes', type=int, help='num of nodes to use to host initial model', default=1) + parser.add_argument('--num_reward_nodes', type=int, help='num of nodes to use to host reward model', default=1) + parser.add_argument('--num_gpus_per_node', type=int, help='num of gpus on a ray node', default=1) + args = parser.parse_args() + ray.init() + main(args) diff --git a/applications/Chat/examples/example_data_reformat.py b/applications/Chat/examples/example_data_reformat.py new file mode 100644 index 0000000000000000000000000000000000000000..dc83b29b525b16ff322126b63042eb32f32ed21e --- /dev/null +++ b/applications/Chat/examples/example_data_reformat.py @@ -0,0 +1,12 @@ +jsonl_file = 'seed_prompts_xx.jsonl' # seed_prompts_en.jsonl or seed_prompts_ch.json from InstructionWild +reformat_file = 'prompts_xx.jsonl' # reformat jsonl file used as Prompt dataset in Stage3 + +data = '' +with open(jsonl_file, 'r', encoding="utf-8") as f1: + for jsonstr in f1.readlines(): + jsonstr = '\t' + jsonstr.strip('\n') + ',\n' + data = data + jsonstr + data = '[\n' + data + ']' + +with open(reformat_file, 'w') as f2: + f2.write(data) \ No newline at end of file diff --git a/applications/Chat/examples/inference.py b/applications/Chat/examples/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..ae59d91c1822825924e87401ee5f5064cf41fbb6 --- /dev/null +++ b/applications/Chat/examples/inference.py @@ -0,0 +1,64 @@ +import argparse + +import torch +from coati.models.bloom import BLOOMActor +from coati.models.gpt import GPTActor +from coati.models.opt import OPTActor +from coati.models.roberta import RoBERTaActor +from transformers import AutoTokenizer, RobertaTokenizer +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + + +def eval(args): + # configure model + if args.model == 'gpt2': + actor = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device()) + elif args.model == 'bloom': + actor = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device()) + elif args.model == 'opt': + actor = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device()) + elif args.model == 'roberta': + actor = RoBERTaActor(pretrained=args.pretrain).to(torch.cuda.current_device()) + else: + raise ValueError(f'Unsupported model "{args.model}"') + + state_dict = torch.load(args.model_path) + actor.model.load_state_dict(state_dict) + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') + elif args.model == 'roberta': + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + else: + raise ValueError(f'Unsupported model "{args.model}"') + + actor.eval() + input = args.input + input_ids = tokenizer.encode(input, return_tensors='pt').to(torch.cuda.current_device()) + outputs = actor.generate(input_ids, + max_length=args.max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1) + output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True) + print(output) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta']) + # We suggest to use the pretrained model from HuggingFace, use pretrain to configure model + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--model_path', type=str, default=None) + parser.add_argument('--input', type=str, default='Question: How are you ? Answer:') + parser.add_argument('--max_length', type=int, default=100) + args = parser.parse_args() + eval(args) diff --git a/applications/Chat/examples/requirements.txt b/applications/Chat/examples/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..40e6edc7ea7303c516ededa9ecd360f0445f957d --- /dev/null +++ b/applications/Chat/examples/requirements.txt @@ -0,0 +1,2 @@ +pandas>=1.4.1 +sentencepiece diff --git a/applications/Chat/examples/test_ci.sh b/applications/Chat/examples/test_ci.sh new file mode 100755 index 0000000000000000000000000000000000000000..2b049163c8012f0d7954a805be41990ecab1d910 --- /dev/null +++ b/applications/Chat/examples/test_ci.sh @@ -0,0 +1,126 @@ +#!/usr/bin/env bash + +set -xue + +if [ -z "$SFT_DATASET" ]; then + echo "Please set \$SFT_DATASET to the path to sft dataset." + exit 1 +fi + +if [ -z "$PROMPT_PATH" ]; then + echo "Please set \$PROMPT_PATH to the path to prompts csv." + exit 1 +fi + +if [ -z "$PRETRAIN_DATASET" ]; then + echo "Please set \$PRETRAIN_DATASET to the path to alpaca data." + exit 1 +fi + +BASE=$(realpath $(dirname $0)) + +export OMP_NUM_THREADS=8 + +# install requirements +pip install -r ${BASE}/requirements.txt + +wandb init -m offline + +# train sft +torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'bigscience/bloom-560m' \ + --model 'bloom' --strategy colossalai_zero2 --lora_rank 4\ + --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ + --save_path ${BASE}/output +rm -rf ${BASE}/output + +torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \ + --model 'gpt2' --strategy colossalai_zero2 \ + --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ + --save_path ${BASE}/output +rm -rf ${BASE}/output + +torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \ + --model 'opt' --strategy colossalai_zero2 --lora_rank 4\ + --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ + --save_path ${BASE}/output +rm -rf ${BASE}/output + +torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \ + --model 'gpt2' --strategy ddp --lora_rank 4\ + --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ + --save_path ${BASE}/output + +#torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \ +# --model 'opt' --strategy naive \ +# --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ +# --save_path ${BASE}/output + +rm -rf ${BASE}/output + +# train rm +torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ + --pretrain 'facebook/opt-350m' --model 'opt' \ + --strategy colossalai_zero2 --loss_fn 'log_sig'\ + --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \ + --test True --lora_rank 0 \ + --save_path ${BASE}/rm_ckpt_opt.pt + +torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ + --pretrain 'gpt2' --model 'gpt2' \ + --strategy colossalai_zero2 --loss_fn 'log_exp' \ + --dataset 'Dahoas/rm-static' \ + --test True --lora_rank 0 \ + --save_path ${BASE}/rm_ckpt_gpt.pt + +torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ + --pretrain 'gpt2' --model 'gpt2' \ + --strategy ddp --loss_fn 'log_exp' \ + --dataset 'Dahoas/rm-static' \ + --test True --lora_rank 4 \ + --save_path ${BASE}/rm_ckpt.pt +rm -rf ${BASE}/rm_ckpt.pt + +torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ + --pretrain 'bigscience/bloom-560m' --model 'bloom' \ + --strategy colossalai_zero2 --loss_fn 'log_sig' \ + --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \ + --test True --lora_rank 4 \ + --save_path ${BASE}/rm_ckpt.pt +rm -rf ${BASE}/rm_ckpt.pt + +torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ + --pretrain 'microsoft/deberta-v3-large' --model 'deberta' \ + --strategy colossalai_zero2 --loss_fn 'log_sig' \ + --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \ + --test True --lora_rank 4 \ + --save_path ${BASE}/rm_ckpt.pt +rm -rf ${BASE}/rm_ckpt.pt + +torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ + --pretrain 'roberta-base' --model 'roberta' \ + --strategy colossalai_zero2 --loss_fn 'log_exp'\ + --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\ + --test True --lora_rank 4 \ + --save_path ${BASE}/rm_ckpt.pt + +rm -rf ${BASE}/rm_ckpt.pt + +torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ + --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \ + --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \ + --pretrain 'facebook/opt-350m' --model opt \ + --rm_pretrain 'facebook/opt-350m' \ + --rm_path ${BASE}/rm_ckpt_opt.pt \ + --save_path ${BASE}/actor_checkpoint_prompts.pt +rm -rf ${BASE}/rm_ckpt_opt.pt + +torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ + --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \ + --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \ + --pretrain 'gpt2' --model gpt2 \ + --rm_pretrain 'gpt2' \ + --rm_path ${BASE}/rm_ckpt_gpt.pt \ + --save_path ${BASE}/actor_checkpoint_prompts.pt +rm -rf ${BASE}/rm_ckpt_gpt.pt + +rm -rf ${BASE}/actor_checkpoint_prompts.pt diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..a584991cd34e00cfb9bf2c97a285e642b25268c5 --- /dev/null +++ b/applications/Chat/examples/train_prompts.py @@ -0,0 +1,233 @@ +import argparse + +import pandas as pd +import torch +import torch.distributed as dist +from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset +from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic +from coati.models.gpt import GPTRM, GPTActor, GPTCritic +from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM +from coati.models.opt import OPTRM, OPTActor, OPTCritic +from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM +from coati.trainer import PPOTrainer +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.utils import prepare_llama_tokenizer_and_embedding +from torch.optim import Adam +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer, RobertaTokenizer + +from colossalai.nn.optimizer import HybridAdam + + +def main(args): + # configure strategy + if args.strategy == 'naive': + strategy = NaiveStrategy() + elif args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + if args.rm_path is not None: + state_dict = torch.load(args.rm_path, map_location='cpu') + + # configure model + if args.model == 'gpt2': + initial_model = GPTActor(pretrained=args.pretrain) + elif args.model == 'bloom': + initial_model = BLOOMActor(pretrained=args.pretrain) + elif args.model == 'opt': + initial_model = OPTActor(pretrained=args.pretrain) + elif args.model == 'llama': + initial_model = LlamaActor(pretrained=args.pretrain) + elif args.model == 'roberta': + initial_model = RoBERTaActor(pretrained=args.pretrain) + else: + raise ValueError(f'Unsupported actor model "{args.model}"') + + if args.rm_model == None: + rm_model_name = args.model + else: + rm_model_name = args.rm_model + + if rm_model_name == 'gpt2': + reward_model = GPTRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'bloom': + reward_model = BLOOMRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'opt': + reward_model = OPTRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'llama': + reward_model = LlamaRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'roberta': + reward_model = RoBERTaRM(pretrained=args.rm_pretrain) + else: + raise ValueError(f'Unsupported reward model "{rm_model_name}"') + + if args.rm_path is not None: + reward_model.load_state_dict(state_dict) + + initial_model.to(torch.float16).to(torch.cuda.current_device()) + reward_model.to(torch.float16).to(torch.cuda.current_device()) + + with strategy.model_init_context(): + if args.model == 'gpt2': + actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) + elif args.model == 'bloom': + actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank) + elif args.model == 'opt': + actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) + elif args.model == 'llama': + actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank) + elif args.model == 'roberta': + actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank) + else: + raise ValueError(f'Unsupported actor model "{args.model}"') + + if rm_model_name == 'gpt2': + critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + elif rm_model_name == 'bloom': + critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + elif rm_model_name == 'opt': + critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + elif rm_model_name == 'llama': + critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + elif rm_model_name == 'roberta': + critic = RoBERTaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + else: + raise ValueError(f'Unsupported reward model "{rm_model_name}"') + + if args.rm_path is not None: + critic.load_state_dict(state_dict) + del state_dict + + if args.strategy != 'colossalai_gemini': + critic.to(torch.float16).to(torch.cuda.current_device()) + actor.to(torch.float16).to(torch.cuda.current_device()) + + # configure optimizer + if args.strategy.startswith('colossalai'): + actor_optim = HybridAdam(actor.parameters(), lr=1e-7) + critic_optim = HybridAdam(critic.parameters(), lr=1e-7) + else: + actor_optim = Adam(actor.parameters(), lr=1e-7) + critic_optim = Adam(critic.parameters(), lr=1e-7) + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + elif args.model == 'llama': + tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) + tokenizer.eos_token = '<\s>' + elif args.model == 'roberta': + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + else: + raise ValueError(f'Unsupported model "{args.model}"') + + if args.model == 'llama': + tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor) + else: + tokenizer.pad_token = tokenizer.eos_token + + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + + prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384) + if dist.is_initialized() and dist.get_world_size() > 1: + prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) + else: + prompt_sampler = None + prompt_dataloader = DataLoader(prompt_dataset, + shuffle=(prompt_sampler is None), + sampler=prompt_sampler, + batch_size=args.experience_batch_size) + + pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, + data_path=args.pretrain_dataset, + max_datasets_size=16384, + max_length=args.max_input_len) + if dist.is_initialized() and dist.get_world_size() > 1: + pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) + else: + pretrain_sampler = None + pretrain_dataloader = DataLoader(pretrain_dataset, + shuffle=(pretrain_sampler is None), + sampler=pretrain_sampler, + batch_size=args.ptx_batch_size, + collate_fn=data_collator) + + (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) + + # configure trainer + trainer = PPOTrainer( + strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + kl_coef=args.kl_coef, + ptx_coef=args.ptx_coef, + max_epochs=args.max_epochs, + train_batch_size=args.train_batch_size, + max_length=args.max_seq_len, + use_cache=True, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + trainer.fit(prompt_dataloader=prompt_dataloader, + pretrain_dataloader=pretrain_dataloader, + num_episodes=args.num_episodes, + max_timesteps=args.max_timesteps, + update_timesteps=args.update_timesteps) + + # save model checkpoint after fitting + strategy.save_model(actor, args.save_path, only_rank0=True) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + strategy.save_optimizer(actor_optim, + 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--prompt_dataset', type=str, default=None, help='path to the prompt dataset') + parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset') + parser.add_argument('--strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='colossalai_zero2', + help='strategy to use') + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta']) + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta']) + parser.add_argument('--rm_path', type=str, default=None) + parser.add_argument('--rm_pretrain', type=str, default=None) + parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--ptx_batch_size', type=int, default=1) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument('--kl_coef', type=float, default=0.1) + parser.add_argument('--ptx_coef', type=float, default=0.9) + parser.add_argument('--max_input_len', type=int, default=96) + parser.add_argument('--max_seq_len', type=int, default=128) + args = parser.parse_args() + main(args) diff --git a/applications/Chat/examples/train_prompts.sh b/applications/Chat/examples/train_prompts.sh new file mode 100755 index 0000000000000000000000000000000000000000..7f3b2636ca32862d03a260c44bfa4765f6f9990e --- /dev/null +++ b/applications/Chat/examples/train_prompts.sh @@ -0,0 +1,20 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 2 + +# torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2 + +torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_dataset /path/to/data.json --strategy colossalai_zero2 diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py new file mode 100644 index 0000000000000000000000000000000000000000..48b12336fa6743714add52ee52c8a518c155f2ab --- /dev/null +++ b/applications/Chat/examples/train_reward_model.py @@ -0,0 +1,210 @@ +import argparse +from random import randint + +import loralib as lora +import torch +import torch.distributed as dist +from coati.dataset import HhRlhfDataset, RmStaticDataset +from coati.models import LogExpLoss, LogSigLoss +from coati.models.base import RewardModel +from coati.models.bloom import BLOOMRM +from coati.models.deberta import DebertaRM +from coati.models.gpt import GPTRM +from coati.models.llama import LlamaRM +from coati.models.opt import OPTRM +from coati.models.roberta import RoBERTaRM +from coati.trainer import RewardModelTrainer +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.utils import prepare_llama_tokenizer_and_embedding +from datasets import load_dataset +from torch.optim import Adam +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.nn.optimizer import HybridAdam + + +def train(args): + # configure strategy + if args.strategy == 'naive': + strategy = NaiveStrategy() + elif args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + # configure model + with strategy.model_init_context(): + if args.model == 'bloom': + model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + elif args.model == 'opt': + model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + elif args.model == 'gpt2': + model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + elif args.model == 'deberta': + model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + elif args.model == 'llama': + model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + elif args.model == 'roberta': + model = RoBERTaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + else: + raise ValueError(f'Unsupported model "{args.model}"') + + if args.model_path is not None: + state_dict = torch.load(args.model_path) + model.load_state_dict(state_dict) + + model = model.to(torch.float16) + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + elif args.model == 'deberta': + tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large') + elif args.model == 'llama': + tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) + elif args.model == 'roberta': + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + else: + raise ValueError(f'Unsupported model "{args.model}"') + max_len = args.max_len + + if args.model == 'llama': + tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model) + else: + tokenizer.pad_token = tokenizer.eos_token + + # configure optimizer + if args.strategy.startswith('colossalai'): + optim = HybridAdam(model.parameters(), lr=5e-6) + else: + optim = Adam(model.parameters(), lr=5e-6) + + # configure loss function + if args.loss_fn == 'log_sig': + loss_fn = LogSigLoss() + elif args.loss_fn == 'log_exp': + loss_fn = LogExpLoss() + else: + raise ValueError(f'Unsupported loss function "{args.loss_fn}"') + + # prepare for data and dataset + if args.subset is not None: + data = load_dataset(args.dataset, data_dir=args.subset) + else: + data = load_dataset(args.dataset) + + if args.test: + train_data = data['train'].select(range(100)) + eval_data = data['test'].select(range(10)) + else: + train_data = data['train'] + eval_data = data['test'] + valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5))) + + if args.dataset == 'Dahoas/rm-static': + train_dataset = RmStaticDataset(train_data, tokenizer, max_len) + valid_dataset = RmStaticDataset(valid_data, tokenizer, max_len) + eval_dataset = RmStaticDataset(eval_data, tokenizer, max_len) + elif args.dataset == 'Anthropic/hh-rlhf': + train_dataset = HhRlhfDataset(train_data, tokenizer, max_len) + valid_dataset = HhRlhfDataset(valid_data, tokenizer, max_len) + eval_dataset = HhRlhfDataset(eval_data, tokenizer, max_len) + else: + raise ValueError(f'Unsupported dataset "{args.dataset}"') + + if dist.is_initialized() and dist.get_world_size() > 1: + train_sampler = DistributedSampler(train_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size()) + valid_sampler = DistributedSampler(valid_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size()) + eval_sampler = DistributedSampler(eval_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size()) + else: + train_sampler = None + valid_sampler = None + eval_sampler = None + + train_dataloader = DataLoader(train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + pin_memory=True) + + valid_dataloader = DataLoader(valid_dataset, + shuffle=(valid_sampler is None), + sampler=valid_sampler, + batch_size=args.batch_size, + pin_memory=True) + + eval_dataloader = DataLoader(eval_dataset, + shuffle=(eval_sampler is None), + sampler=eval_sampler, + batch_size=args.batch_size, + pin_memory=True) + + (model, optim) = strategy.prepare((model, optim)) + trainer = RewardModelTrainer(model=model, + strategy=strategy, + optim=optim, + loss_fn=loss_fn, + train_dataloader=train_dataloader, + valid_dataloader=valid_dataloader, + eval_dataloader=eval_dataloader, + max_epochs=args.max_epochs) + + trainer.fit() + # save model checkpoint after fitting on only rank0 + strategy.save_model(model, args.save_path, only_rank0=True) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + strategy.save_optimizer(trainer.optimizer, + 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='colossalai_zero2') + parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta', 'llama', 'roberta'], default='bloom') + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--model_path', type=str, default=None) + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--dataset', + type=str, + choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'], + default='Dahoas/rm-static') + parser.add_argument('--subset', type=str, default=None) + parser.add_argument('--save_path', type=str, default='rm_ckpt') + parser.add_argument('--max_epochs', type=int, default=1) + parser.add_argument('--batch_size', type=int, default=1) + parser.add_argument('--max_len', type=int, default=512) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp']) + parser.add_argument('--test', type=bool, default=False) + args = parser.parse_args() + train(args) diff --git a/applications/Chat/examples/train_rm.sh b/applications/Chat/examples/train_rm.sh new file mode 100755 index 0000000000000000000000000000000000000000..80abe62d2a3fe9d70c0ab8be1b2e8e3b8afc5e03 --- /dev/null +++ b/applications/Chat/examples/train_rm.sh @@ -0,0 +1,24 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 2 + +torchrun --standalone --nproc_per_node=2 train_reward_model.py \ + --pretrain \ + --model 'bloom' \ + --strategy colossalai_zero2 \ + --loss_fn 'log_sig'\ + --save_path \ + --dataset 'Anthropic/hh-rlhf'\ diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py new file mode 100644 index 0000000000000000000000000000000000000000..da499f068b17885ac468ecd1dcb9de49100f667c --- /dev/null +++ b/applications/Chat/examples/train_sft.py @@ -0,0 +1,196 @@ +import argparse +import os + +import loralib as lora +import torch +import torch.distributed as dist +from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset +from coati.models import convert_to_lora_module +from coati.trainer import SFTTrainer +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.utils import prepare_llama_tokenizer_and_embedding +from datasets import load_dataset +from torch.optim import Adam +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM, BloomTokenizerFast, LlamaConfig, LlamaForCausalLM +from transformers.models.gpt2.configuration_gpt2 import GPT2Config +from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer +from transformers.models.opt.configuration_opt import OPTConfig +from transformers.models.opt.modeling_opt import OPTForCausalLM + +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor import ColoParameter + + +def train(args): + # configure strategy + if args.strategy == 'naive': + strategy = NaiveStrategy() + elif args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + raise NotImplementedError( + 'Gemini is not supported .from_pretrained() yet. We will update this after checkpoint io is ready.') + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2_cpu': + strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + # configure model + with strategy.model_init_context(): + if args.model == 'bloom': + model = convert_to_lora_module(BloomForCausalLM.from_pretrained(args.pretrain), + args.lora_rank).half().cuda() + elif args.model == 'opt': + model = convert_to_lora_module(OPTForCausalLM.from_pretrained(args.pretrain), args.lora_rank).half().cuda() + elif args.model == 'gpt2': + model = convert_to_lora_module(GPT2LMHeadModel.from_pretrained(args.pretrain), args.lora_rank).half().cuda() + elif args.model == 'llama': + model = convert_to_lora_module(LlamaForCausalLM.from_pretrained(args.pretrain), + args.lora_rank).half().cuda() + else: + raise ValueError(f'Unsupported model "{args.model}"') + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + elif args.model == 'llama': + tokenizer = AutoTokenizer.from_pretrained( + args.pretrain, + padding_side="right", + use_fast=False, + ) + tokenizer.eos_token = '<\s>' + else: + raise ValueError(f'Unsupported model "{args.model}"') + tokenizer.pad_token = tokenizer.eos_token + max_len = args.max_len + if args.model == 'llama': + tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model) + + if args.strategy == 'colossalai_gemini': + # this is a hack to deal with the resized embedding + # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatiblity + for name, param in model.named_parameters(): + if not isinstance(param, ColoParameter): + sub_module_name = '.'.join(name.split('.')[:-1]) + weight_name = name.split('.')[-1] + sub_module = model.get_submodule(sub_module_name) + setattr(sub_module, weight_name, ColoParameter(param)) + else: + tokenizer.pad_token = tokenizer.eos_token + + # configure optimizer + if args.strategy.startswith('colossalai'): + optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) + else: + optim = Adam(model.parameters(), lr=args.lr) + + logger = get_dist_logger() + + # configure dataset + if args.dataset == 'yizhongw/self_instruct': + train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train') + eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test') + + train_dataset = SFTDataset(train_data, tokenizer, max_len) + eval_dataset = SFTDataset(eval_data, tokenizer, max_len) + + else: + train_dataset = SupervisedDataset(tokenizer=tokenizer, + data_path=args.dataset, + max_datasets_size=args.max_datasets_size, + max_length=max_len) + eval_dataset = None + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + + if dist.is_initialized() and dist.get_world_size() > 1: + train_sampler = DistributedSampler(train_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size()) + if eval_dataset is not None: + eval_sampler = DistributedSampler(eval_dataset, + shuffle=False, + seed=42, + drop_last=False, + rank=dist.get_rank(), + num_replicas=dist.get_world_size()) + else: + train_sampler = None + eval_sampler = None + + train_dataloader = DataLoader(train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True) + if eval_dataset is not None: + eval_dataloader = DataLoader(eval_dataset, + shuffle=(eval_sampler is None), + sampler=eval_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True) + else: + eval_dataloader = None + + (model, optim) = strategy.prepare((model, optim)) + trainer = SFTTrainer(model=model, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + max_epochs=args.max_epochs, + accumulation_steps=args.accumulation_steps) + + trainer.fit(logger=logger, use_wandb=args.use_wandb) + + # save model checkpoint after fitting on only rank0 + strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + strategy.save_optimizer(trainer.optimizer, + 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'], + default='colossalai_zero2') + parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--dataset', type=str, default=None) + parser.add_argument('--max_datasets_size', type=int, default=None) + parser.add_argument('--save_path', type=str, default='output') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--max_epochs', type=int, default=3) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--max_len', type=int, default=512) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") + parser.add_argument('--lr', type=float, default=5e-6) + parser.add_argument('--accumulation_steps', type=int, default=8) + parser.add_argument('--use_wandb', default=False, action='store_true') + parser.add_argument('--grad_checkpoint', default=False, action='store_true') + args = parser.parse_args() + train(args) diff --git a/applications/Chat/examples/train_sft.sh b/applications/Chat/examples/train_sft.sh new file mode 100755 index 0000000000000000000000000000000000000000..c880f85825a77a98ea49ce691bb5cf4fcabca857 --- /dev/null +++ b/applications/Chat/examples/train_sft.sh @@ -0,0 +1,12 @@ +torchrun --standalone --nproc_per_node=4 train_sft.py \ + --pretrain "/path/to/LLaMa-7B/" \ + --model 'llama' \ + --strategy colossalai_zero2 \ + --log_interval 10 \ + --save_path /path/to/Coati-7B \ + --dataset /path/to/data.json \ + --batch_size 4 \ + --accumulation_steps 8 \ + --lr 2e-5 \ + --max_datasets_size 512 \ + --max_epochs 1 \ diff --git a/applications/Chat/inference/README.md b/applications/Chat/inference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..434677c98fa58f7050098671c6243c6f70d023a4 --- /dev/null +++ b/applications/Chat/inference/README.md @@ -0,0 +1,118 @@ +# Inference + +We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models. + +We support 8-bit quantization (RTN), which is powered by [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [transformers](https://github.com/huggingface/transformers). And 4-bit quantization (GPTQ), which is powered by [gptq](https://github.com/IST-DASLab/gptq) and [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). We also support FP16 inference. + +We only support LLaMA family models now. + +## Choosing precision (quantization) + +**FP16**: Fastest, best output quality, highest memory usage + +**8-bit**: Slow, easier setup (originally supported by transformers), lower output quality (due to RTN), **recommended for first-timers** + +**4-bit**: Faster, lowest memory usage, higher output quality (due to GPTQ), but more difficult setup + +## Hardware requirements for LLaMA + +Tha data is from [LLaMA Int8 4bit ChatBot Guide v2](https://rentry.org/llama-tard-v2). + +### 8-bit + +| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples | +| :---: | :---: | :---: | :---: | :---: | +| LLaMA-7B | 9.2GB | 10GB | 24GB | 3060 12GB, RTX 3080 10GB, RTX 3090 | +| LLaMA-13B | 16.3GB | 20GB | 32GB | RTX 3090 Ti, RTX 4090 | +| LLaMA-30B | 36GB | 40GB | 64GB | A6000 48GB, A100 40GB | +| LLaMA-65B | 74GB | 80GB | 128GB | A100 80GB | + +### 4-bit + +| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples | +| :---: | :---: | :---: | :---: | :---: | +| LLaMA-7B | 3.5GB | 6GB | 16GB | RTX 1660, 2060, AMD 5700xt, RTX 3050, 3060 | +| LLaMA-13B | 6.5GB | 10GB | 32GB | AMD 6900xt, RTX 2060 12GB, 3060 12GB, 3080, A2000 | +| LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 | +| LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada | + +## General setup + +```shell +pip install -r requirements.txt +``` + +## 8-bit setup + +8-bit quantization is originally supported by the latest [transformers](https://github.com/huggingface/transformers). Please install it from source. + +Please ensure you have downloaded HF-format model weights of LLaMA models. + +Usage: + +```python +import torch +from transformers import LlamaForCausalLM + +USE_8BIT = True # use 8-bit quantization; otherwise, use fp16 + +model = LlamaForCausalLM.from_pretrained( + "pretrained/path", + load_in_8bit=USE_8BIT, + torch_dtype=torch.float16, + device_map="auto", + ) +if not USE_8BIT: + model.half() # use fp16 +model.eval() +``` + +**Troubleshooting**: if you get error indicating your CUDA-related libraries not found when loading 8-bit model, you can check whether your `LD_LIBRARY_PATH` is correct. + +E.g. you can set `export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH`. + +## 4-bit setup + +Please ensure you have downloaded HF-format model weights of LLaMA models first. + +Then you can follow [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). This lib provides efficient CUDA kernels and weight convertion script. + +After installing this lib, we may convert the original HF-format LLaMA model weights to 4-bit version. + +```shell +CUDA_VISIBLE_DEVICES=0 python llama.py /path/to/pretrained/llama-7b c4 --wbits 4 --groupsize 128 --save llama7b-4bit.pt +``` + +Run this command in your cloned `GPTQ-for-LLaMa` directory, then you will get a 4-bit weight file `llama7b-4bit-128g.pt`. + +**Troubleshooting**: if you get error about `position_ids`, you can checkout to commit `50287c3b9ae4a3b66f6b5127c643ec39b769b155`(`GPTQ-for-LLaMa` repo). + +## Online inference server + +In this directory: + +```shell +export CUDA_VISIBLE_DEVICES=0 +# fp16, will listen on 0.0.0.0:7070 by default +python server.py /path/to/pretrained +# 8-bit, will listen on localhost:8080 +python server.py /path/to/pretrained --quant 8bit --http_host localhost --http_port 8080 +# 4-bit +python server.py /path/to/pretrained --quant 4bit --gptq_checkpoint /path/to/llama7b-4bit-128g.pt --gptq_group_size 128 +``` + +## Benchmark + +In this directory: + +```shell +export CUDA_VISIBLE_DEVICES=0 +# fp16 +python benchmark.py /path/to/pretrained +# 8-bit +python benchmark.py /path/to/pretrained --quant 8bit +# 4-bit +python benchmark.py /path/to/pretrained --quant 4bit --gptq_checkpoint /path/to/llama7b-4bit-128g.pt --gptq_group_size 128 +``` + +This benchmark will record throughput and peak CUDA memory usage. diff --git a/applications/Chat/inference/benchmark.py b/applications/Chat/inference/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..59cd1eeea2aa841ae805d91adf279f668fdb3dd0 --- /dev/null +++ b/applications/Chat/inference/benchmark.py @@ -0,0 +1,132 @@ +# Adapted from https://github.com/tloen/alpaca-lora/blob/main/generate.py + +import argparse +from time import time + +import torch +from llama_gptq import load_quant +from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM + + +def generate_prompt(instruction, input=None): + if input: + return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + +### Instruction: +{instruction} + +### Input: +{input} + +### Response:""" + else: + return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. + +### Instruction: +{instruction} + +### Response:""" + + +@torch.no_grad() +def evaluate( + model, + tokenizer, + instruction, + input=None, + temperature=0.1, + top_p=0.75, + top_k=40, + num_beams=4, + max_new_tokens=128, + **kwargs, +): + prompt = generate_prompt(instruction, input) + inputs = tokenizer(prompt, return_tensors="pt") + input_ids = inputs["input_ids"].cuda() + generation_config = GenerationConfig( + temperature=temperature, + top_p=top_p, + top_k=top_k, + num_beams=num_beams, + **kwargs, + ) + generation_output = model.generate( + input_ids=input_ids, + generation_config=generation_config, + return_dict_in_generate=True, + output_scores=True, + max_new_tokens=max_new_tokens, + do_sample=True, + ) + s = generation_output.sequences[0] + output = tokenizer.decode(s) + n_new_tokens = s.size(0) - input_ids.size(1) + return output.split("### Response:")[1].strip(), n_new_tokens + + +instructions = [ + "Tell me about alpacas.", + "Tell me about the president of Mexico in 2019.", + "Tell me about the king of France in 2019.", + "List all Canadian provinces in alphabetical order.", + "Write a Python program that prints the first 10 Fibonacci numbers.", + "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.", + "Tell me five words that rhyme with 'shock'.", + "Translate the sentence 'I have no mouth but I must scream' into Spanish.", + "Count up from 1 to 500.", + # === + "How to play support in legends of league", + "Write a Python program that calculate Fibonacci numbers.", +] +inst = [instructions[0]] * 4 + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + 'pretrained', + help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.') + parser.add_argument('--quant', + choices=['8bit', '4bit'], + default=None, + help='Quantization mode. Default: None (no quantization, fp16).') + parser.add_argument( + '--gptq_checkpoint', + default=None, + help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.') + parser.add_argument('--gptq_group_size', + type=int, + default=128, + help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.') + args = parser.parse_args() + + if args.quant == '4bit': + assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.' + + tokenizer = AutoTokenizer.from_pretrained(args.pretrained) + + if args.quant == '4bit': + model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size) + model.cuda() + else: + model = LlamaForCausalLM.from_pretrained( + args.pretrained, + load_in_8bit=(args.quant == '8bit'), + torch_dtype=torch.float16, + device_map="auto", + ) + if args.quant != '8bit': + model.half() # seems to fix bugs for some users. + model.eval() + + total_tokens = 0 + start = time() + for instruction in instructions: + print(f"Instruction: {instruction}") + resp, tokens = evaluate(model, tokenizer, instruction, temparature=0.2, num_beams=1) + total_tokens += tokens + print(f"Response: {resp}") + print('\n----------------------------\n') + duration = time() - start + print(f'Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s') + print(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB') diff --git a/applications/Chat/inference/llama_gptq/__init__.py b/applications/Chat/inference/llama_gptq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..51c8d6316290fe2fcef7d972803017c830d3e1b4 --- /dev/null +++ b/applications/Chat/inference/llama_gptq/__init__.py @@ -0,0 +1,5 @@ +from .loader import load_quant + +__all__ = [ + 'load_quant', +] diff --git a/applications/Chat/inference/llama_gptq/loader.py b/applications/Chat/inference/llama_gptq/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..a5c6ac7d1589aa1873918b9c8b02edcfe13ed59f --- /dev/null +++ b/applications/Chat/inference/llama_gptq/loader.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +import transformers +from transformers import LlamaConfig, LlamaForCausalLM + +from .model_utils import find_layers +from .quant import make_quant + + +def load_quant(pretrained: str, checkpoint: str, wbits: int, groupsize: int): + config = LlamaConfig.from_pretrained(pretrained) + + def noop(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = noop + torch.nn.init.uniform_ = noop + torch.nn.init.normal_ = noop + + torch.set_default_dtype(torch.half) + transformers.modeling_utils._init_weights = False + torch.set_default_dtype(torch.half) + model = LlamaForCausalLM(config) + torch.set_default_dtype(torch.float) + model = model.eval() + layers = find_layers(model) + for name in ['lm_head']: + if name in layers: + del layers[name] + make_quant(model, layers, wbits, groupsize) + + print(f'Loading model with {wbits} bits...') + if checkpoint.endswith('.safetensors'): + from safetensors.torch import load_file as safe_load + model.load_state_dict(safe_load(checkpoint)) + else: + model.load_state_dict(torch.load(checkpoint)) + model.seqlen = 2048 + print('Done.') + + return model diff --git a/applications/Chat/inference/llama_gptq/model_utils.py b/applications/Chat/inference/llama_gptq/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..62db171abb52cb88799a8b73d608f2617208cefe --- /dev/null +++ b/applications/Chat/inference/llama_gptq/model_utils.py @@ -0,0 +1,13 @@ +# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py + +import torch +import torch.nn as nn + + +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1)) + return res diff --git a/applications/Chat/inference/llama_gptq/quant.py b/applications/Chat/inference/llama_gptq/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..f7d5b7ce4bd8217bf246abbef0736c78be3869a6 --- /dev/null +++ b/applications/Chat/inference/llama_gptq/quant.py @@ -0,0 +1,283 @@ +# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/quant.py + +import math + +import numpy as np +import torch +import torch.nn as nn + + +def quantize(x, scale, zero, maxq): + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + + +class Quantizer(nn.Module): + + def __init__(self, shape=1): + super(Quantizer, self).__init__() + self.register_buffer('maxq', torch.tensor(0)) + self.register_buffer('scale', torch.zeros(shape)) + self.register_buffer('zero', torch.zeros(shape)) + + def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8): + self.maxq = torch.tensor(2**bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse + self.norm = norm + self.grid = grid + self.maxshrink = maxshrink + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + if self.perchannel: + if weight: + x = x.flatten(1) + else: + if len(shape) == 4: + x = x.permute([1, 0, 2, 3]) + x = x.flatten(1) + if len(shape) == 3: + x = x.reshape((-1, shape[-1])).t() + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) + + if self.mse: + best = torch.full([x.shape[0]], float('inf'), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) + q -= x + q.abs_() + q.pow_(self.norm) + err = torch.sum(q, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + if weight: + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + return + if len(shape) == 4: + self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero = self.zero.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = self.scale.reshape((1, 1, -1)) + self.zero = self.zero.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero = self.zero.unsqueeze(0) + + def quantize(self, x): + if self.ready(): + return quantize(x, self.scale, self.zero, self.maxq) + return x + + def enabled(self): + return self.maxq > 0 + + def ready(self): + return torch.all(self.scale != 0) + + +try: + import quant_cuda +except: + print('CUDA extension not installed.') + +# Assumes layer is perfectly divisible into 256 * 256 blocks + + +class QuantLinear(nn.Module): + + def __init__(self, bits, groupsize, infeatures, outfeatures): + super().__init__() + if bits not in [2, 3, 4, 8]: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + if groupsize != -1 and groupsize < 32 and groupsize != int(math.pow(2, int(math.log2(groupsize)))): + raise NotImplementedError("groupsize supports powers of 2 greater than 32. (e.g. : 32,64,128,etc)") + groupsize = groupsize if groupsize != -1 else infeatures + self.groupsize = groupsize + self.register_buffer( + 'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), + dtype=torch.int)) + self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures))) + self.register_buffer('bias', torch.zeros(outfeatures)) + self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)) + self._initialized_quant_state = False + + def pack(self, linear, scales, zeros): + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone() + if linear.bias is not None: + self.bias = linear.bias.clone() + + intweight = [] + for idx in range(self.infeatures): + g_idx = idx // self.groupsize + intweight.append( + torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:, + None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros((intweight.shape[0] // 256 * (self.bits * 8), intweight.shape[1]), dtype=np.uint32) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i)) + i += 10 + qweight[row] |= intweight[i] << 30 + row += 1 + qweight[row] |= (intweight[i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 1) + i += 10 + qweight[row] |= intweight[i] << 31 + row += 1 + qweight[row] |= (intweight[i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 2) + i += 10 + row += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i)) + i += 10 + qzeros[:, col] |= zeros[:, i] << 30 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1) + i += 10 + qzeros[:, col] |= zeros[:, i] << 31 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2) + i += 10 + col += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + intermediate_dtype = torch.float32 + + if not self._initialized_quant_state: + # Do we even have a bias? Check for at least one non-zero element. + if self.bias is not None and bool(torch.any(self.bias != 0)): + # Then make sure it's the right type. + self.bias.data = self.bias.data.to(intermediate_dtype) + else: + self.bias = None + + outshape = list(x.shape) + outshape[-1] = self.outfeatures + x = x.reshape(-1, x.shape[-1]) + if self.bias is None: + y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device) + else: + y = self.bias.clone().repeat(x.shape[0], 1) + + output_dtype = x.dtype + x = x.to(intermediate_dtype) + if self.bits == 2: + quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize) + elif self.bits == 3: + quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize) + elif self.bits == 4: + quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize) + elif self.bits == 8: + quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize) + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + y = y.to(output_dtype) + return y.reshape(outshape) + + +def make_quant(module, names, bits, groupsize, name=''): + if isinstance(module, QuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + '.' + attr if name != '' else attr + if name1 in names: + setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features)) + for name1, child in module.named_children(): + make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) diff --git a/applications/Chat/inference/locustfile.py b/applications/Chat/inference/locustfile.py new file mode 100644 index 0000000000000000000000000000000000000000..51cdc68125bba42d29a91f03285847e0bde27ea8 --- /dev/null +++ b/applications/Chat/inference/locustfile.py @@ -0,0 +1,27 @@ +from json import JSONDecodeError + +from locust import HttpUser, task + +samples = [[ + dict( + instruction='Who is the best player in the history of NBA?', + response= + 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' + ), + dict(instruction='continue this talk', response=''), +], [ + dict(instruction='Who is the best player in the history of NBA?', response=''), +]] + + +class GenerationUser(HttpUser): + + @task + def generate(self): + for sample in samples: + data = {'max_new_tokens': 64, 'history': sample} + with self.client.post('/generate', json=data, catch_response=True) as response: + if response.status_code in (200, 406): + response.success() + else: + response.failure('Response wrong') diff --git a/applications/Chat/inference/requirements.txt b/applications/Chat/inference/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..511fe1a4f1f339b23f1a11162925703b94e15013 --- /dev/null +++ b/applications/Chat/inference/requirements.txt @@ -0,0 +1,13 @@ +fastapi +locust +numpy +pydantic +safetensors +slowapi +sse_starlette +torch +uvicorn +git+https://github.com/huggingface/transformers +accelerate +bitsandbytes +jieba \ No newline at end of file diff --git a/applications/Chat/inference/server.py b/applications/Chat/inference/server.py new file mode 100644 index 0000000000000000000000000000000000000000..b4627299397e6949576318de3938ee9c19aa390c --- /dev/null +++ b/applications/Chat/inference/server.py @@ -0,0 +1,178 @@ +import argparse +import os +from threading import Lock +from typing import Dict, Generator, List, Optional + +import torch +import uvicorn +from fastapi import FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from llama_gptq import load_quant +from pydantic import BaseModel, Field +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.errors import RateLimitExceeded +from slowapi.util import get_remote_address +from sse_starlette.sse import EventSourceResponse +from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM +from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn, load_json + +CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.' +MAX_LEN = 512 +running_lock = Lock() + + +class GenerationTaskReq(BaseModel): + max_new_tokens: int = Field(gt=0, le=512, example=64) + history: List[Dialogue] = Field(min_items=1) + top_k: Optional[int] = Field(default=None, gt=0, example=50) + top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) + temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) + repetition_penalty: Optional[float] = Field(default=None, gt=1.0, example=1.2) + + +limiter = Limiter(key_func=get_remote_address) +app = FastAPI() +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + +# set CORS +origin_spec_from_env = os.environ.get('CORS_ORIGIN', None) + +if origin_spec_from_env is not None: + # allow CORS from the specified origins + origins = os.environ['CORS_ORIGIN'].split(',') +else: + # allow CORS from all origins + origins = ["*"] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature): + inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()} + #TODO(ver217): streaming generation does not support repetition_penalty now + model_kwargs = { + 'max_generate_tokens': max_new_tokens, + 'early_stopping': True, + 'top_k': top_k, + 'top_p': top_p, + 'temperature': temperature, + 'prepare_inputs_fn': model.prepare_inputs_for_generation, + 'update_model_kwargs_fn': update_model_kwargs_fn, + } + is_first_word = True + generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock) + for output in generator: + output = output.cpu() + tokens = tokenizer.convert_ids_to_tokens(output, skip_special_tokens=True) + current_sub_tokens = [] + for token in tokens: + if token in tokenizer.all_special_tokens: + continue + current_sub_tokens.append(token) + if current_sub_tokens: + out_string = tokenizer.sp_model.decode(current_sub_tokens) + if is_first_word: + out_string = out_string.lstrip() + is_first_word = False + elif current_sub_tokens[0].startswith('▁'): + # whitespace will be ignored by the frontend + out_string = ' ' + out_string + yield out_string + + +async def event_generator(request: Request, generator: Generator): + while True: + if await request.is_disconnected(): + break + try: + yield {'event': 'generate', 'data': next(generator)} + except StopIteration: + yield {'event': 'end', 'data': ''} + break + + +@app.post('/generate/stream') +@limiter.limit('1/second') +def generate(data: GenerationTaskReq, request: Request): + prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens) + event_source = event_generator( + request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)) + return EventSourceResponse(event_source) + + +@app.post('/generate') +@limiter.limit('1/second') +def generate_no_stream(data: GenerationTaskReq, request: Request): + prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens) + if prompt_processor.has_censored_words(prompt): + return prompt_processor.SAFE_RESPONSE + inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()} + with running_lock: + output = model.generate(**inputs, **data.dict(exclude={'history'})) + output = output.cpu() + prompt_len = inputs['input_ids'].size(1) + response = output[0, prompt_len:] + out_string = tokenizer.decode(response, skip_special_tokens=True) + out_string = prompt_processor.postprocess_output(out_string) + if prompt_processor.has_censored_words(out_string): + return prompt_processor.SAFE_RESPONSE + return out_string + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + 'pretrained', + help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.') + parser.add_argument('--quant', + choices=['8bit', '4bit'], + default=None, + help='Quantization mode. Default: None (no quantization, fp16).') + parser.add_argument( + '--gptq_checkpoint', + default=None, + help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.') + parser.add_argument('--gptq_group_size', + type=int, + default=128, + help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.') + parser.add_argument('--http_host', default='0.0.0.0') + parser.add_argument('--http_port', type=int, default=7070) + parser.add_argument('--profanity_file', default=None, help='Path to profanity words list. It should be a JSON file containing a list of words.') + args = parser.parse_args() + + if args.quant == '4bit': + assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.' + + tokenizer = AutoTokenizer.from_pretrained(args.pretrained) + + if args.profanity_file is not None: + censored_words = load_json(args.profanity_file) + else: + censored_words = [] + prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words) + + if args.quant == '4bit': + model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size) + model.cuda() + else: + model = LlamaForCausalLM.from_pretrained( + args.pretrained, + load_in_8bit=(args.quant == '8bit'), + torch_dtype=torch.float16, + device_map="auto", + ) + if args.quant != '8bit': + model.half() # seems to fix bugs for some users. + model.eval() + + config = uvicorn.Config(app, host=args.http_host, port=args.http_port) + server = uvicorn.Server(config=config) + server.run() diff --git a/applications/Chat/inference/tests/test_chat_prompt.py b/applications/Chat/inference/tests/test_chat_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..f5737ebe8c097d73073bb21195341b378e7fc2f1 --- /dev/null +++ b/applications/Chat/inference/tests/test_chat_prompt.py @@ -0,0 +1,56 @@ +import os + +from transformers import AutoTokenizer +from utils import ChatPromptProcessor, Dialogue + +CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.' +tokenizer = AutoTokenizer.from_pretrained(os.environ['PRETRAINED_PATH']) + +samples = [ + ([ + Dialogue( + instruction='Who is the best player in the history of NBA?', + response= + 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' + ), + Dialogue(instruction='continue this talk', response=''), + ], 128, + 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' + ), + ([ + Dialogue( + instruction='Who is the best player in the history of NBA?', + response= + 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' + ), + Dialogue(instruction='continue this talk', response=''), + ], 200, + 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' + ), + ([ + Dialogue( + instruction='Who is the best player in the history of NBA?', + response= + 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' + ), + Dialogue(instruction='continue this talk', response=''), + ], 211, + 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n' + ), + ([ + Dialogue(instruction='Who is the best player in the history of NBA?', response=''), + ], 128, + 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n' + ), +] + + +def test_chat_prompt_processor(): + processor = ChatPromptProcessor(tokenizer, CONTEXT, 256) + for history, max_new_tokens, result in samples: + prompt = processor.preprocess_prompt(history, max_new_tokens) + assert prompt == result + + +if __name__ == '__main__': + test_chat_prompt_processor() diff --git a/applications/Chat/inference/utils.py b/applications/Chat/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..37944be70a3bf9631f0a995bd1753d71d6c8b5aa --- /dev/null +++ b/applications/Chat/inference/utils.py @@ -0,0 +1,200 @@ +import re +from threading import Lock +from typing import Any, Callable, Generator, List, Optional +import json +import jieba + +import torch +import torch.distributed as dist +import torch.nn as nn +from pydantic import BaseModel, Field + +try: + from transformers.generation_logits_process import ( + LogitsProcessorList, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + ) +except ImportError: + from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper + + +def prepare_logits_processor(top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None) -> LogitsProcessorList: + processor_list = LogitsProcessorList() + if temperature is not None and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if top_k is not None and top_k != 0: + processor_list.append(TopKLogitsWarper(top_k)) + if top_p is not None and top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + return processor_list + + +def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: + if dist.is_initialized() and dist.get_world_size() > 1: + # consider DP + unfinished_sequences = unfinished_sequences.clone() + dist.all_reduce(unfinished_sequences) + return unfinished_sequences.max() == 0 + + +def sample_streamingly(model: nn.Module, + input_ids: torch.Tensor, + max_generate_tokens: int, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs) -> Generator: + + logits_processor = prepare_logits_processor(top_k, top_p, temperature) + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + for _ in range(max_generate_tokens): + model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else { + 'input_ids': input_ids + } + outputs = model(**model_inputs) + + next_token_logits = outputs['logits'][:, -1, :] + # pre-process distribution + next_token_logits = logits_processor(input_ids, next_token_logits) + # sample + probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + yield next_tokens + + # update generated ids, model inputs for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if update_model_kwargs_fn is not None: + model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + + # stop when each sentence is finished if early_stopping=True + if early_stopping and _is_sequence_finished(unfinished_sequences): + break + + +def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict: + if "past_key_values" in outputs: + model_kwargs["past"] = outputs["past_key_values"] + else: + model_kwargs["past"] = None + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) + + return model_kwargs + + +class Dialogue(BaseModel): + instruction: str = Field(min_length=1, example='Count up from 1 to 500.') + response: str = Field(example='') + + +def _format_dialogue(instruction: str, response: str = ''): + return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}' + + +STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S)) + + +class ChatPromptProcessor: + SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.' + + def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str]=[]): + self.tokenizer = tokenizer + self.context = context + self.max_len = max_len + self.censored_words = set([word.lower() for word in censored_words]) + # These will be initialized after the first call of preprocess_prompt() + self.context_len: Optional[int] = None + self.dialogue_placeholder_len: Optional[int] = None + + def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str: + if self.context_len is None: + self.context_len = len(self.tokenizer(self.context)['input_ids']) + if self.dialogue_placeholder_len is None: + self.dialogue_placeholder_len = len( + self.tokenizer(_format_dialogue(''), add_special_tokens=False)['input_ids']) + prompt = self.context + # the last dialogue must be in the prompt + last_dialogue = history.pop() + # the response of the last dialogue is empty + assert last_dialogue.response == '' + if len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False) + ['input_ids']) + max_new_tokens + self.context_len >= self.max_len: + # to avoid truncate placeholder, apply truncate to the original instruction + instruction_truncated = self.tokenizer(last_dialogue.instruction, + add_special_tokens=False, + truncation=True, + max_length=(self.max_len - max_new_tokens - self.context_len - + self.dialogue_placeholder_len))['input_ids'] + instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip() + prompt += _format_dialogue(instruction_truncated) + return prompt + + res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)['input_ids']) + + rows = [] + for dialogue in history[::-1]: + text = _format_dialogue(dialogue.instruction, dialogue.response) + cur_len = len(self.tokenizer(text, add_special_tokens=False)['input_ids']) + if res_len - cur_len < 0: + break + res_len -= cur_len + rows.insert(0, text) + prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction) + return prompt + + def postprocess_output(self, output: str) -> str: + output = STOP_PAT.sub('', output) + return output.strip() + + def has_censored_words(self, text: str) -> bool: + if len(self.censored_words) == 0: + return False + intersection = set(jieba.cut(text.lower())) & self.censored_words + return len(intersection) > 0 + +class LockedIterator: + + def __init__(self, it, lock: Lock) -> None: + self.lock = lock + self.it = iter(it) + + def __iter__(self): + return self + + def __next__(self): + with self.lock: + return next(self.it) + +def load_json(path: str): + with open(path) as f: + return json.load(f) \ No newline at end of file diff --git a/applications/Chat/pytest.ini b/applications/Chat/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..01e5cd217c5d9830551735c7d9e6d4087e038391 --- /dev/null +++ b/applications/Chat/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +markers = + cpu: tests which can run on CPU + gpu: tests which requires a single GPU + dist: tests which are run in a multi-GPU or multi-machine environment + experiment: tests for experimental features diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt new file mode 100644 index 0000000000000000000000000000000000000000..e079f8a6038dd2dc8512967540f96ee0de172067 --- /dev/null +++ b/applications/Chat/requirements-test.txt @@ -0,0 +1 @@ +pytest diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..af7ff67861eb73573489e9ae46f1a2d29eaa02b3 --- /dev/null +++ b/applications/Chat/requirements.txt @@ -0,0 +1,13 @@ +transformers>=4.20.1 +tqdm +datasets +loralib +colossalai>=0.2.4 +torch<2.0.0, >=1.12.1 +langchain +tokenizers +fastapi +sse_starlette +wandb +sentencepiece +gpustat diff --git a/applications/Chat/setup.py b/applications/Chat/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..a285a6dff4bf9cfe6905494de83a0baacfd795cb --- /dev/null +++ b/applications/Chat/setup.py @@ -0,0 +1,41 @@ +from setuptools import find_packages, setup + + +def fetch_requirements(path): + with open(path, 'r') as fd: + return [r.strip() for r in fd.readlines()] + + +def fetch_readme(): + with open('README.md', encoding='utf-8') as f: + return f.read() + + +def fetch_version(): + with open('version.txt', 'r') as f: + return f.read().strip() + + +setup( + name='coati', + version=fetch_version(), + packages=find_packages(exclude=( + 'tests', + 'benchmarks', + '*.egg-info', + )), + description='Colossal-AI Talking Intelligence', + long_description=fetch_readme(), + long_description_content_type='text/markdown', + license='Apache Software License 2.0', + url='https://github.com/hpcaitech/Coati', + install_requires=fetch_requirements('requirements.txt'), + python_requires='>=3.6', + classifiers=[ + 'Programming Language :: Python :: 3', + 'License :: OSI Approved :: Apache Software License', + 'Environment :: GPU :: NVIDIA CUDA', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: System :: Distributed Computing', + ], +) diff --git a/applications/Chat/tests/__init__.py b/applications/Chat/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..4c05a343169905dc03d91b1f4c39530ec8834848 --- /dev/null +++ b/applications/Chat/tests/test_checkpoint.py @@ -0,0 +1,94 @@ +import os +import tempfile +from contextlib import nullcontext + +import pytest +import torch +import torch.distributed as dist +from coati.models.gpt import GPTActor +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy +from transformers.models.gpt2.configuration_gpt2 import GPT2Config + +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import rerun_if_address_is_in_use, spawn + +GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4) + + +def get_data(batch_size: int, seq_len: int = 10) -> dict: + input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda') + attention_mask = torch.ones_like(input_ids) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def run_test_checkpoint(strategy): + BATCH_SIZE = 2 + + if strategy == 'ddp': + strategy = DDPStrategy() + elif strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) + elif strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + else: + raise ValueError(f'Unsupported strategy "{strategy}"') + + with strategy.model_init_context(): + actor = GPTActor(config=GPT_CONFIG).cuda() + + actor_optim = HybridAdam(actor.parameters()) + + actor, actor_optim = strategy.prepare((actor, actor_optim)) + + def run_step(): + data = get_data(BATCH_SIZE) + action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool) + action_log_probs = actor(data['input_ids'], action_mask.size(1), data['attention_mask']) + loss = action_log_probs.sum() + strategy.backward(loss, actor, actor_optim) + strategy.optimizer_step(actor_optim) + + run_step() + + ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext() + + with ctx as dirname: + rank0_dirname = [dirname] + dist.broadcast_object_list(rank0_dirname) + rank0_dirname = rank0_dirname[0] + + model_path = os.path.join(rank0_dirname, 'model.pt') + optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt') + + strategy.save_model(actor, model_path, only_rank0=True) + strategy.save_optimizer(actor_optim, optim_path, only_rank0=False) + + dist.barrier() + + strategy.load_model(actor, model_path, strict=False) + strategy.load_optimizer(actor_optim, optim_path) + + dist.barrier() + + run_step() + + +def run_dist(rank, world_size, port, strategy): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = str(port) + run_test_checkpoint(strategy) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini']) +@rerun_if_address_is_in_use() +def test_checkpoint(world_size, strategy): + spawn(run_dist, world_size, strategy=strategy) + + +if __name__ == '__main__': + test_checkpoint(2, 'colossalai_zero2') diff --git a/applications/Chat/tests/test_data.py b/applications/Chat/tests/test_data.py new file mode 100644 index 0000000000000000000000000000000000000000..2e4d4ceac05fa603b98e4ad1c9b098e221345e83 --- /dev/null +++ b/applications/Chat/tests/test_data.py @@ -0,0 +1,118 @@ +import os +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +from coati.experience_maker import NaiveExperienceMaker +from coati.models.base import RewardModel +from coati.models.gpt import GPTActor, GPTCritic +from coati.replay_buffer import NaiveReplayBuffer +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy +from transformers.models.gpt2.configuration_gpt2 import GPT2Config + +from colossalai.testing import rerun_if_address_is_in_use, spawn + +GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4) + + +def get_data(batch_size: int, seq_len: int = 10) -> dict: + input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda') + attention_mask = torch.ones_like(input_ids) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def gather_and_equal(tensor: torch.Tensor) -> bool: + world_size = dist.get_world_size() + outputs = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(outputs, tensor.contiguous()) + for t in outputs[1:]: + if not torch.equal(outputs[0], t): + return False + return True + + +def run_test_data(strategy): + EXPERINCE_BATCH_SIZE = 4 + SAMPLE_BATCH_SIZE = 2 + + if strategy == 'ddp': + strategy = DDPStrategy() + elif strategy == 'colossalai': + strategy = ColossalAIStrategy(placement_policy='cuda') + else: + raise ValueError(f'Unsupported strategy "{strategy}"') + + actor = GPTActor(config=GPT_CONFIG).cuda() + critic = GPTCritic(config=GPT_CONFIG).cuda() + + initial_model = deepcopy(actor) + reward_model = RewardModel(deepcopy(critic.model)).cuda() + + experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model) + replay_buffer = NaiveReplayBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False) + + # experience of all ranks should be the same + for _ in range(2): + data = get_data(EXPERINCE_BATCH_SIZE) + assert gather_and_equal(data['input_ids']) + assert gather_and_equal(data['attention_mask']) + experience = experience_maker.make_experience(**data, + do_sample=True, + max_length=16, + eos_token_id=50256, + pad_token_id=50256) + assert gather_and_equal(experience.sequences) + assert gather_and_equal(experience.action_log_probs) + assert gather_and_equal(experience.values) + assert gather_and_equal(experience.reward) + assert gather_and_equal(experience.advantages) + assert gather_and_equal(experience.action_mask) + assert gather_and_equal(experience.attention_mask) + replay_buffer.append(experience) + + # replay buffer's data should be the same + buffer_size = torch.tensor([len(replay_buffer)], device='cuda') + assert gather_and_equal(buffer_size) + for item in replay_buffer.items: + assert gather_and_equal(item.sequences) + assert gather_and_equal(item.action_log_probs) + assert gather_and_equal(item.values) + assert gather_and_equal(item.reward) + assert gather_and_equal(item.advantages) + assert gather_and_equal(item.action_mask) + assert gather_and_equal(item.attention_mask) + + # dataloader of each rank should have the same size and different batch + dataloader = strategy.setup_dataloader(replay_buffer) + dataloader_size = torch.tensor([len(dataloader)], device='cuda') + assert gather_and_equal(dataloader_size) + for experience in dataloader: + assert not gather_and_equal(experience.sequences) + assert not gather_and_equal(experience.action_log_probs) + assert not gather_and_equal(experience.values) + assert not gather_and_equal(experience.reward) + assert not gather_and_equal(experience.advantages) + # action mask and attention mask may be same + + +def run_dist(rank, world_size, port, strategy): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = str(port) + run_test_data(strategy) + + +@pytest.mark.skip +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('strategy', ['ddp', 'colossalai']) +@rerun_if_address_is_in_use() +def test_data(world_size, strategy): + spawn(run_dist, world_size, strategy=strategy) + + +if __name__ == '__main__': + test_data(2, 'colossalai') diff --git a/applications/Chat/version.txt b/applications/Chat/version.txt new file mode 100644 index 0000000000000000000000000000000000000000..3eefcb9dd5b38e2c1dc061052455dd97bcd51e6c --- /dev/null +++ b/applications/Chat/version.txt @@ -0,0 +1 @@ +1.0.0 diff --git a/applications/README.md b/applications/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cd0435aae199aade3981c1af6502cf42d0d8578e --- /dev/null +++ b/applications/README.md @@ -0,0 +1,12 @@ +# Applications + +This directory contains the applications that are powered by Colossal-AI. + +The list of applications include: + +- [X] [Chatbot](./Chat/README.md) +- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters + +> Please note that the `Chatbot` application is migrated from the original `ChatGPT` folder. + +You can find more example code for base models and functions in the [Examples](https://github.com/hpcaitech/ColossalAI/tree/main/examples) directory. diff --git a/colossalai/_C/__init__.py b/colossalai/_C/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colossalai/__init__.py b/colossalai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f859161f78108e4c8b5cfba89e12026fee28da43 --- /dev/null +++ b/colossalai/__init__.py @@ -0,0 +1,17 @@ +from .initialize import ( + get_default_parser, + initialize, + launch, + launch_from_openmpi, + launch_from_slurm, + launch_from_torch, +) + +try: + # .version will be created by setup.py + from .version import __version__ +except ModuleNotFoundError: + # this will only happen if the user did not run `pip install` + # and directly set PYTHONPATH to use Colossal-AI which is a bad practice + __version__ = '0.0.0' + print('please install Colossal-AI from https://www.colossalai.org/download or from source') diff --git a/colossalai/_analyzer/README.md b/colossalai/_analyzer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c5c55eddd325a2f94b221e78945c51b6b14c21ae --- /dev/null +++ b/colossalai/_analyzer/README.md @@ -0,0 +1,306 @@ +# Analyzer + +# Overview +The Analyzer is a collection of static graph utils including Colossal-AI FX. Features include: +- MetaTensor -- enabling: + - Ahead-of-time Profiling + - Shape Propagation + - Ideal Flop Counter +- symbolic_trace() + - Robust Control-flow Tracing / Recompile + - Robust Activation Checkpoint Tracing / CodeGen + - Easy-to-define Bias-Addition Split +- symbolic_profile() + - Support ``MetaTensorMode``, where all Tensor operations are executed symbolically. + - Shape Inference Across Device and Unified ``MetaInfo`` + - Ideal Flop Counter https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505 + +# Quickstart +## Analyzer.FX +**Reference:** + + https://pytorch.org/docs/stable/fx.html [[paper](https://arxiv.org/pdf/2112.08429)] + + +torch.FX is a toolkit for developers to use to transform nn.Module instances. FX consists of three main components: a symbolic tracer, an intermediate representation, and Python code generation. FX.Tracer hacks _\_\_torch_function\_\__ and use a Proxy object to propagate through any forward function of torch.nn.Module. +![image](https://user-images.githubusercontent.com/78588128/212531495-bbb934dd-dbbb-4578-8869-6171973f7dd8.png) +ColossalAI FX is modified from torch.FX, with the extra capability of ahead-of-time profiling enabled by the subclass of ``MetaTensor``. + +### Analyzer.FX.symbolic_trace() +A drawback of the original torch.FX implementation is that it is poor at handling control flow. All control flow is not PyTorch native operands and requires actual instances that specify the branches to execute on. For example, + +```python +class MyModule(nn.Module): + def forward(self, x): + if x.dim() == 3: + return x * 2 + 1 + else: + return x - 5 +``` + +The above function has the computation graph of + +![image](https://user-images.githubusercontent.com/78588128/212532631-dba30734-577b-4418-8dc9-004d7983abc5.png) + +However, since Proxy does not have concrete data, applying ``x.dim()`` will return nothing. In the context of the auto-parallel system, at least the control-flow dependencies for tensor shape should be removed, since any searched strategy could only auto-parallelize a specific computation graph with the same tensor shape. It is native to attach concrete data onto a Proxy, and propagate them through control flow. + +![image](https://user-images.githubusercontent.com/78588128/212533403-1b620986-1c3a-420a-87c6-d08c9702135d.png) + + +With ``MetaTensor``, the computation during shape propagation can be virtualized. This speeds up tracing by avoiding allocating actual memory on devices. + +#### Remarks +There is no free lunch for PyTorch to unify all operands in both its repo and other repos in its eco-system. For example, the einops library currently has no intention to support torch.FX (See https://github.com/arogozhnikov/einops/issues/188). To support different PyTorch-based libraries without modifying source code, good practices can be to allow users to register their implementation to substitute the functions not supported by torch.FX, or to avoid entering incompatible submodules. + +### Analyzer.FX.symbolic_profile() + +``symbolic_profile`` is another important feature of Colossal-AI's auto-parallel system. Profiling DNN can be costly, as you need to allocate memory and execute on real devices. However, since the profiling requirements for auto-parallel is enough if we can detect when and where the intermediate activations (i.e. Tensor) are generated, we can profile the whole procedure without actually executing it. ``symbolic_profile``, as its name infers, profiles the whole network with symbolic information only. + +```python +with MetaTensorMode(): + model = MyModule().cuda() + sample = torch.rand(100, 3, 224, 224).cuda() +meta_args = dict( + x = sample, +) +gm = symbolic_trace(model, meta_args=meta_args) +gm = symbolic_profile(gm, sample) +``` + +``symbolic_profile`` is enabled by ``ShapeProp`` and ``GraphProfile``. + +#### ShapeProp +Both Tensor Parallel and Activation Checkpoint solvers need to know the shape information ahead of time. Unlike PyTorch's implementation, this ``ShapeProp`` can be executed under MetaTensorMode. With this, all the preparation for auto-parallel solvers can be done in milliseconds. + +Meanwhile, it is easy to keep track of the memory usage of each node when doing shape propagation. However, the drawbacks of FX is that not every ``call_function`` saves its input for backward, and different tensor that flows within one FX.Graph can actually have the same layout. This raises problems for fine-grained profiling. + +![image](https://user-images.githubusercontent.com/78588128/215312957-7eb6cbc3-61b2-49cf-95a4-6b859149eb8d.png) + +To address this problem, I came up with a simulated environment enabled by ``torch.autograd.graph.saved_tensor_hooks`` and fake ``data_ptr`` (check ``_subclasses/meta_tensor.py`` for more details of ``data_ptr`` updates). + +```python +class sim_env(saved_tensors_hooks): + """ + A simulation of memory allocation and deallocation in the forward pass + using ``saved_tensor_hooks``. + + Attributes: + ctx (Dict[int, torch.Tensor]): A dictionary that maps the + data pointer of a tensor to the tensor itself. This is used + to track the memory allocation and deallocation. + + param_ctx (Dict[int, torch.Tensor]): A dictionary that maps the + data pointer of all model parameters to the parameter itself. + This avoids overestimating the memory usage of the intermediate activations. + """ + + def __init__(self, module: Optional[torch.nn.Module] = None): + super().__init__(self.pack_hook, self.unpack_hook) + self.ctx = {} + self.param_ctx = {param.data_ptr(): param for param in module.parameters()} + self.buffer_ctx = {buffer.data_ptr(): buffer for buffer in module.buffers()} if module else {} + + def pack_hook(self, tensor: torch.Tensor): + if tensor.data_ptr() not in self.param_ctx and tensor.data_ptr() not in self.buffer_ctx: + self.ctx[tensor.data_ptr()] = tensor + return tensor + + def unpack_hook(self, tensor): + return tensor +``` +The ``ctx`` variable will keep track of all saved tensors with a unique identifier. It is likely that ``nn.Parameter`` is also counted in the ``ctx``, which is not desired. To avoid this, we can use ``param_ctx`` to keep track of all parameters in the model. The ``buffer_ctx`` is used to keep track of all buffers in the model. The ``local_ctx`` that is attached to each ``Node`` marks the memory usage of the stage to which the node belongs. With simple ``intersect``, ``union`` and ``subtract`` operations, we can get any memory-related information. For non-profileable nodes, you might add your customized profile rules to simulate the memory allocation. If a ``Graph`` is modified with some non-PyTorch functions, such as fused operands, you can register the shape propagation rule with the decorator. + +```python +@register_shape_impl(fuse_conv_bn) +def fuse_conv_bn_shape_impl(*args, **kwargs): + # infer output shape here + return torch.empty(output_shape, device=output_device) +``` + +An important notice is that ``ShapeProp`` will attach additional information to the graph, which will be exactly the input of ``Profiler``. + +#### GraphProfiler +``GraphProfiler`` executes at the node level, and profiles both forward and backward within one node. For example, ``FlopProfiler`` will profile the forward and backward FLOPs of a node, and ``CommunicationProfiler`` will profile the forward and backward communication cost of a node. The ``GraphProfiler`` will attach the profiling results to the ``Node``. These procedures are decoupled for better extensibility. + +To provide a general insight of the profiled results, you can set ``verbose=True`` to print the summary as well. +```python +model = tm.resnet18() +sample = torch.rand(100, 3, 224, 224) +meta_args = dict(x=sample) +gm = symbolic_trace(model, meta_args=meta_args) +gm = symbolic_profile(gm, sample, verbose=True) + +============================================================ Results ===================================================================== + Op type Op Accumulate size Incremental size Output size Temp size Param size Backward size Fwd FLOPs Bwd FLOPs +------------- ---------------------------------------------- ----------------- ------------------ ------------- ----------- ------------ --------------- ------------- ------------- + placeholder x 4.59 Mb 0 b 4.59 Mb 0 b 0 b 0 b 0 FLOPs 0 FLOPs + call_module conv_proj 4.59 Mb 0 b 0 b 4.59 Mb 2.25 Mb 4.59 Mb 924.84 MFLOPs 924.84 MFLOPs + call_method reshape 4.59 Mb 0 b 0 b 4.59 Mb 0 b 4.59 Mb 0 FLOPs 0 FLOPs + call_method permute 4.59 Mb 0 b 0 b 4.59 Mb 0 b 4.59 Mb 0 FLOPs 0 FLOPs + get_attr class_token 4.59 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs + call_method expand 4.59 Mb 0 b 0 b 24.00 Kb 3.00 Kb 0 b 0 FLOPs 6.14 kFLOPs +call_function cat 4.59 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs + get_attr encoder_pos_embedding 4.59 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs +call_function add 9.21 Mb 4.62 Mb 4.62 Mb 0 b 591.00 Kb 4.62 Mb 1.21 MFLOPs 1.21 MFLOPs + call_module encoder_dropout 9.21 Mb 0 b 4.62 Mb 0 b 0 b 4.62 Mb 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_0_ln_1 9.22 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_0_self_attention 46.52 Mb 37.30 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs +call_function getitem 46.52 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs +call_function getitem_1 46.52 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_0_dropout 46.52 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_1 51.14 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_0_ln_2 51.15 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_0_mlp_0 74.24 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_0_mlp_1 92.71 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs + call_module encoder_layers_encoder_layer_0_mlp_2 92.71 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_0_mlp_3 92.71 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_0_mlp_4 92.71 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_2 97.32 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_1_ln_1 101.95 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_1_self_attention 134.63 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs +call_function getitem_2 134.63 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs +call_function getitem_3 134.63 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_1_dropout 134.63 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_3 139.25 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_1_ln_2 139.26 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_1_mlp_0 162.35 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_1_mlp_1 180.82 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs + call_module encoder_layers_encoder_layer_1_mlp_2 180.82 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_1_mlp_3 180.82 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_1_mlp_4 180.82 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_4 185.43 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_2_ln_1 190.06 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_2_self_attention 222.74 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs +call_function getitem_4 222.74 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs +call_function getitem_5 222.74 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_2_dropout 222.74 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_5 227.36 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_2_ln_2 227.37 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_2_mlp_0 250.46 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_2_mlp_1 268.93 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs + call_module encoder_layers_encoder_layer_2_mlp_2 268.93 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_2_mlp_3 268.93 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_2_mlp_4 268.93 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_6 273.54 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_3_ln_1 278.17 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_3_self_attention 310.86 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs +call_function getitem_6 310.86 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs +call_function getitem_7 310.86 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_3_dropout 310.86 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_7 315.47 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_3_ln_2 315.48 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_3_mlp_0 338.57 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_3_mlp_1 357.04 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs + call_module encoder_layers_encoder_layer_3_mlp_2 357.04 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_3_mlp_3 357.04 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_3_mlp_4 357.04 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_8 361.66 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_4_ln_1 366.29 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_4_self_attention 398.97 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs +call_function getitem_8 398.97 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs +call_function getitem_9 398.97 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_4_dropout 398.97 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_9 403.58 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_4_ln_2 403.60 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_4_mlp_0 426.68 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_4_mlp_1 445.15 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs + call_module encoder_layers_encoder_layer_4_mlp_2 445.15 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_4_mlp_3 445.15 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_4_mlp_4 445.15 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_10 449.77 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_5_ln_1 454.40 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_5_self_attention 487.08 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs +call_function getitem_10 487.08 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs +call_function getitem_11 487.08 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_5_dropout 487.08 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_11 491.70 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_5_ln_2 491.71 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_5_mlp_0 514.79 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_5_mlp_1 533.26 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs + call_module encoder_layers_encoder_layer_5_mlp_2 533.26 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_5_mlp_3 533.26 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_5_mlp_4 533.26 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_12 537.88 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_6_ln_1 542.51 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_6_self_attention 575.19 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs +call_function getitem_12 575.19 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs +call_function getitem_13 575.19 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_6_dropout 575.19 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_13 579.81 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_6_ln_2 579.82 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_6_mlp_0 602.90 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_6_mlp_1 621.37 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs + call_module encoder_layers_encoder_layer_6_mlp_2 621.37 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_6_mlp_3 621.37 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_6_mlp_4 621.37 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_14 625.99 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_7_ln_1 630.62 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_7_self_attention 663.30 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs +call_function getitem_14 663.30 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs +call_function getitem_15 663.30 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_7_dropout 663.30 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_15 667.92 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_7_ln_2 667.93 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_7_mlp_0 691.02 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_7_mlp_1 709.48 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs + call_module encoder_layers_encoder_layer_7_mlp_2 709.48 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_7_mlp_3 709.48 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_7_mlp_4 709.48 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_16 714.10 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_8_ln_1 718.73 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_8_self_attention 751.41 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs +call_function getitem_16 751.41 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs +call_function getitem_17 751.41 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_8_dropout 751.41 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_17 756.03 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_8_ln_2 756.04 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_8_mlp_0 779.13 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_8_mlp_1 797.60 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs + call_module encoder_layers_encoder_layer_8_mlp_2 797.60 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_8_mlp_3 797.60 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_8_mlp_4 797.60 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_18 802.21 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_9_ln_1 806.84 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_9_self_attention 839.52 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs +call_function getitem_18 839.52 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs +call_function getitem_19 839.52 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_9_dropout 839.52 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_19 844.14 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_9_ln_2 844.15 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_9_mlp_0 867.24 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_9_mlp_1 885.71 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs + call_module encoder_layers_encoder_layer_9_mlp_2 885.71 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_9_mlp_3 885.71 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_9_mlp_4 885.71 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_20 890.32 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_10_ln_1 894.95 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_10_self_attention 927.63 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs +call_function getitem_20 927.63 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs +call_function getitem_21 927.63 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_10_dropout 927.63 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_21 932.25 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_10_ln_2 932.26 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_10_mlp_0 955.35 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_10_mlp_1 973.82 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs + call_module encoder_layers_encoder_layer_10_mlp_2 973.82 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_10_mlp_3 973.82 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_10_mlp_4 973.82 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_22 978.44 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_11_ln_1 983.06 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_11_self_attention 1015.75 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs +call_function getitem_22 1015.75 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs +call_function getitem_23 1015.75 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_11_dropout 1015.75 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_23 1020.36 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_11_ln_2 1020.38 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs + call_module encoder_layers_encoder_layer_11_mlp_0 1.02 Gb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_11_mlp_1 1.04 Gb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs + call_module encoder_layers_encoder_layer_11_mlp_2 1.04 Gb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs + call_module encoder_layers_encoder_layer_11_mlp_3 1.04 Gb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs + call_module encoder_layers_encoder_layer_11_mlp_4 1.04 Gb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs +call_function add_24 1.04 Gb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs + call_module encoder_ln 1.04 Gb 36.31 Kb 24.00 Kb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs +call_function getitem_24 1.04 Gb 0 b 24.00 Kb 0 b 0 b 4.62 Mb 0 FLOPs 0 FLOPs + call_module heads_head 1.04 Gb 0 b 0 b 31.25 Kb 2.93 Mb 24.00 Kb 6.14 MFLOPs 12.30 MFLOPs + output output 1.04 Gb 0 b 0 b 31.25 Kb 0 b 31.25 Kb 0 FLOPs 0 FLOPs +``` diff --git a/colossalai/_analyzer/_subclasses/__init__.py b/colossalai/_analyzer/_subclasses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8464fed25edfbd598c43c5c587cb2f3de7cb2e09 --- /dev/null +++ b/colossalai/_analyzer/_subclasses/__init__.py @@ -0,0 +1,4 @@ +from ._meta_registration import * +from ._monkey_patch import * +from .flop_tensor import flop_count, flop_mapping +from .meta_tensor import MetaTensor, MetaTensorMode diff --git a/colossalai/_analyzer/_subclasses/_meta_registration.py b/colossalai/_analyzer/_subclasses/_meta_registration.py new file mode 100644 index 0000000000000000000000000000000000000000..4049be79c70fc1d9c33807d74c2a00fe17c05acd --- /dev/null +++ b/colossalai/_analyzer/_subclasses/_meta_registration.py @@ -0,0 +1,468 @@ +# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py +# should be activated for PyTorch version 1.12.0 and below +# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml +# for more meta_registrations + +from typing import Callable, List, Optional, Tuple, Union + +import torch +from packaging import version +from torch.utils._pytree import tree_map + +aten = torch.ops.aten + +try: + meta_lib = torch.library.Library("aten", "IMPL", "Meta") +except AttributeError: + meta_lib = None + +meta_table = {} + +orig_empty = torch.empty +orig_empty_strided = torch.empty_strided +orig_empty_like = torch.empty_like + + +def new(*args, **kwargs): + return orig_empty(*args, **kwargs, device=torch.device('meta')) + + +def new_strided(*args, **kwargs): + return orig_empty_strided(*args, **kwargs, device=torch.device('meta')) + + +def new_like(*args, **kwargs): + return orig_empty_like(*args, **kwargs, device=torch.device('meta')) + + +def register_meta(op, register_dispatcher=True): + + def wrapper(f): + + def add_func(op): + meta_table[op] = f + if register_dispatcher: + name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__) + try: + meta_lib.impl(name, f) + except: + pass + + tree_map(add_func, op) + return f + + return wrapper + + +if version.parse(torch.__version__) >= version.parse('1.12.0'): + # ============================== Convolutions ====================================== + # https://github.com/pytorch/pytorch/pull/79834 + @register_meta(aten.convolution.default) + def meta_conv( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + is_transposed: bool, + output_padding: List[int], + groups: int, + ): + + def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: + """ + Formula to apply to calculate the length of some dimension of the output + See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + Args: + ln: length of the dimension + p: padding in that dim + d: dilation in that dim + k: kernel size in that dim + s: stride in that dim + Returns: + The output length + """ + return (ln + 2 * p - d * (k - 1) - 1) // s + 1 + + def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int: + """ + Formula to apply to calculate the length of some dimension of the output + if transposed convolution is used. + See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + Args: + ln: length of the dimension + p: padding in that dim + d: dilation in that dim + k: kernel size in that dim + s: stride in that dim + op: output padding in that dim + Returns: + The output length + """ + return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1 + + def calc_conv_nd_return_shape( + dims: torch.Size, + kernel_size: torch.Size, + stride: Union[List[int], int], + padding: Union[List[int], int], + dilation: Union[List[int], int], + output_padding: Optional[Union[List[int], int]] = None, + ): + ret_shape = [] + if isinstance(stride, int): + stride = [stride] * len(dims) + elif len(stride) == 1: + stride = [stride[0]] * len(dims) + + if isinstance(padding, int): + padding = [padding] * len(dims) + elif len(padding) == 1: + padding = [padding[0]] * len(dims) + + if isinstance(dilation, int): + dilation = [dilation] * len(dims) + elif len(dilation) == 1: + dilation = [dilation[0]] * len(dims) + + output_padding_list: Optional[List[int]] = None + if output_padding: + if isinstance(output_padding, int): + output_padding_list = [output_padding] * len(dims) + elif len(output_padding) == 1: + output_padding_list = [output_padding[0]] * len(dims) + else: + output_padding_list = output_padding + + for i in range(len(dims)): + # If output_padding is present, we are dealing with a transposed convolution + if output_padding_list: + ret_shape.append( + _formula_transposed( + dims[i], + padding[i], + dilation[i], + kernel_size[i], + stride[i], + output_padding_list[i], + )) + else: + ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])) + return ret_shape + + def pick_memory_format(): + if input_tensor.is_contiguous(memory_format=torch.channels_last): + return torch.channels_last + elif input_tensor.is_contiguous(memory_format=torch.contiguous_format): + return torch.contiguous_format + elif input_tensor.is_contiguous(memory_format=torch.preserve_format): + return torch.preserve_format + + kernel_size = weight.shape[2:] + dims = input_tensor.shape[2:] + if is_transposed: + out_channels = groups * weight.shape[1] + + shape_out = calc_conv_nd_return_shape( + dims, + kernel_size, + stride, + padding, + dilation, + output_padding, + ) + + else: + out_channels = weight.shape[0] + if weight.shape[1] != input_tensor.shape[1] / groups: + raise RuntimeError("Invalid channel dimensions") + shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation) + out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) + mem_fmt = pick_memory_format() + out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] + return out + + @register_meta(aten._convolution.default) + def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int], + padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int, + *extra_args): + out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) + return out + + @register_meta(aten.convolution_backward.default) + def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, + padding, dilation, transposed, output_padding, groups, output_mask): + return new_like(input), new_like(weight), new((bias_sizes)) + + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp + @register_meta(aten._adaptive_avg_pool2d_backward.default) + def meta_adaptive_avg_pool2d_backward( + grad_output: torch.Tensor, + input: torch.Tensor, + ): + return new_like(input) + + # ================================ RNN ============================================= + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp + @register_meta(aten._cudnn_rnn.default) + def meta_cuda_rnn( + input, + weight, + weight_stride0, + weight_buf, + hx, + cx, + mode, + hidden_size, + proj_size, + num_layers, + batch_first, + dropout, + train, + bidirectional, + batch_sizes, + dropout_state, + ): + + is_input_packed = len(batch_sizes) != 0 + if is_input_packed: + seq_length = len(batch_sizes) + mini_batch = batch_sizes[0] + batch_sizes_sum = input.shape[0] + else: + seq_length = input.shape[1] if batch_first else input.shape[0] + mini_batch = input.shape[0] if batch_first else input.shape[1] + batch_sizes_sum = -1 + + num_directions = 2 if bidirectional else 1 + out_size = proj_size if proj_size != 0 else hidden_size + if is_input_packed: + out_shape = [batch_sizes_sum, out_size * num_directions] + else: + out_shape = ([mini_batch, seq_length, out_size * + num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions]) + output = input.new_empty(out_shape) + + cell_shape = [num_layers * num_directions, mini_batch, hidden_size] + cy = new(0) if cx is None else cx.new_empty(cell_shape) + + hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size]) + + # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python) + reserve_shape = 0 if train else 0 + reserve = input.new_empty(reserve_shape, dtype=torch.uint8) + + return output, hy, cy, reserve, weight_buf + + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp + @register_meta(aten._cudnn_rnn_backward.default) + def meta_cudnn_rnn_backward(input: torch.Tensor, + weight: torch.Tensor, + weight_stride0: int, + hx: torch.Tensor, + cx: Optional[torch.Tensor] = None, + *args, + **kwargs): + return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new( + ()) # (grad_input, grad_weight, grad_hx, grad_cx) + + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp + # ============================== Activations ======================================= + _unregistered_ewise = [ + aten.relu.default, + aten.prelu.default, + aten.hardswish.default, + aten.hardtanh.default, + aten.hardswish_backward.default, + aten.hardtanh_backward.default, + ] + + if version.parse(torch.__version__) < version.parse('2.0.0'): + _unregistered_ewise += [ + aten.prelu_backward.default, + ] + + @register_meta(_unregistered_ewise) + def meta_unregistered_ewise(input: torch.Tensor, *args): + return new_like(input) + + # ============================== Normalization ===================================== + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp + @register_meta(aten.native_batch_norm.default) + def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): + n_input = input.size(1) + return new_like(input), new((n_input)), new((n_input)) + + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp + @register_meta(aten.native_batch_norm_backward.default) + def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, + save_mean, save_invstd, train, eps, output_mask): + return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) + + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp + @register_meta(aten.cudnn_batch_norm.default) + def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): + n_input = input.size(1) + return new_like(input), new((n_input)), new((n_input)), new( + (0), dtype=torch.uint8) # (output, running_mean, running_var, reserve) + + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp + # NB: CuDNN only implements the backward algorithm for batchnorm + # in training mode (evaluation mode batchnorm has a different algorithm), + # which is why this doesn't accept a 'training' parameter. + @register_meta(aten.cudnn_batch_norm_backward.default) + def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, + save_mean, save_invstd, eps, reserve): + return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) + + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp + @register_meta(aten.native_layer_norm.default) + def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): + bs, n_input = input.size(0), input.size(1) + return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var) + + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp + @register_meta(aten.native_layer_norm_backward.default) + def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, + grad_input_mask): + return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta) + + # ================================== Misc ========================================== + # Maybe incorrect + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Im2Col.cpp + @register_meta(aten.im2col.default) + def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride): + return new_like(input) + + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml + @register_meta(aten.roll.default) + def meta_roll(input: torch.Tensor, shifts, dims): + return input + + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp + @register_meta(aten._local_scalar_dense.default) + def meta_local_scalar_dense(self: torch.Tensor): + return 0 + + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp + @register_meta(aten.where.self) + def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor): + result_type = torch.result_type(self, other) + return new_like(condition + self + other, dtype=result_type) + + # ============================== Embedding ========================================= + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp + + @register_meta(aten.embedding_dense_backward.default) + def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, + scale_grad_by_freq): + return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout) + + # ============================== Dropout =========================================== + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp + @register_meta(aten.native_dropout.default) + def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False): + # notice that mask is bool + return new_like(input), new_like(input, dtype=torch.bool) # (output, mask) + + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp + @register_meta(aten.native_dropout_backward.default) + def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float): + return new_like(grad) # (grad_in) + + if version.parse(torch.__version__) < version.parse('1.13.0'): + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml + @register_meta(aten.eye.m_out) + def meta_eye(n: int, m: int, out: torch.Tensor): + return out + + @register_meta(aten.index.Tensor) + def meta_index_Tensor(self, indices): + assert indices, "at least one index must be provided" + # aten::index is the internal advanced indexing implementation + # checkIndexTensorTypes and expandTensors + result: List[Optional[torch.Tensor]] = [] + for i, index in enumerate(indices): + if index is not None: + assert index.dtype in [torch.long, torch.int8, torch.bool],\ + "tensors used as indices must be long, byte or bool tensors" + if index.dtype in [torch.int8, torch.bool]: + nonzero = index.nonzero() + k = len(result) + assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" + for j in range(index.ndim): + assert index.shape[j] == self.shape[ + k + + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" + result.append(nonzero.select(1, j)) + else: + result.append(index) + else: + result.append(index) + indices = result + assert len( + indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})" + # expand_outplace + import torch._refs as refs + + indices = list(refs._maybe_broadcast(*indices)) + # add missing null tensors + while len(indices) < self.ndim: + indices.append(None) + + # hasContiguousSubspace + # true if all non-null tensors are adjacent + # See: + # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency + state = 0 + has_contiguous_subspace = False + for index in indices: + if state == 0: + if index is not None: + state = 1 + elif state == 1: + if index is None: + state = 2 + else: + if index is not None: + break + else: + has_contiguous_subspace = True + + # transposeToFront + # This is the logic that causes the newly inserted dimensions to show up + # at the beginning of the tensor, if they're not contiguous + if not has_contiguous_subspace: + dims = [] + transposed_indices = [] + for i, index in enumerate(indices): + if index is not None: + dims.append(i) + transposed_indices.append(index) + for i, index in enumerate(indices): + if index is None: + dims.append(i) + transposed_indices.append(index) + self = self.permute(dims) + indices = transposed_indices + + # AdvancedIndex::AdvancedIndex + # Now we can assume the indices have contiguous subspace + # This is simplified from AdvancedIndex which goes to more effort + # to put the input and indices in a form so that TensorIterator can + # take them. If we write a ref for this, probably that logic should + # get implemented + before_shape: List[int] = [] + after_shape: List[int] = [] + replacement_shape: List[int] = [] + for dim, index in enumerate(indices): + if index is None: + if replacement_shape: + after_shape.append(self.shape[dim]) + else: + before_shape.append(self.shape[dim]) + else: + replacement_shape = list(index.shape) + return self.new_empty(before_shape + replacement_shape + after_shape) diff --git a/colossalai/_analyzer/_subclasses/_monkey_patch.py b/colossalai/_analyzer/_subclasses/_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..b3ec98f0811f265e370c638369269f73e589dfd6 --- /dev/null +++ b/colossalai/_analyzer/_subclasses/_monkey_patch.py @@ -0,0 +1,93 @@ +import torch +import torch.distributed as dist +from packaging import version + +__all__ = [ + "_TorchFactoryMethod", + "_TorchOverrideableFactoryMethod", + "_TorchNonOverrideableFactoryMethod", + "_TensorPropertyMethod", + "_DistCommMethod", + "_AliasATen", + "_InplaceATen", + "_MaybeInplaceATen", +] + +_TorchOverrideableFactoryMethod = [ + "empty", + "eye", + "full", + "ones", + "rand", + "randn", + "zeros", +] + +_TorchNonOverrideableFactoryMethod = [ + "arange", + "finfo", + "linspace", + "logspace", + "randint", + "randperm", + "tensor", +] + +_TorchFactoryMethod = _TorchOverrideableFactoryMethod + _TorchNonOverrideableFactoryMethod + +_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"] + +_DistCommMethod = [ + "all_gather", + "all_reduce", + "all_to_all", + "broadcast", + "gather", + "reduce", + "reduce_scatter", + "scatter", +] + +if version.parse(torch.__version__) >= version.parse('1.12.0'): + aten = torch.ops.aten + # TODO: dive deep here + # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp + _AliasATen = [ + aten.detach.default, + aten.detach_.default, + aten.t.default, + aten.transpose.int, + aten.view.default, + aten._unsafe_view.default, + aten._reshape_alias.default, + ] + + _InplaceATen = [ + aten.add_.Tensor, + aten.add_.Scalar, + aten.sub_.Tensor, + aten.sub_.Scalar, + aten.mul_.Tensor, + aten.mul_.Scalar, + aten.div_.Tensor, + aten.div_.Scalar, + aten.pow_.Tensor, + aten.pow_.Scalar, + ] + + # use `MaybeInplace` because they call ``as_strided()`` or ``slice()`` + _MaybeInplaceATen = [ + aten.diagonal.default, + aten.expand.default, + aten.select.int, + aten.slice.Tensor, + aten.split.Tensor, + aten.squeeze.default, + aten.permute.default, + aten.unsqueeze.default, + aten.as_strided.default, + ] +else: + _AliasATen = [] + _InplaceATen = [] + _MaybeInplaceATen = [] diff --git a/colossalai/_analyzer/_subclasses/flop_tensor.py b/colossalai/_analyzer/_subclasses/flop_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..59991dc5091254631dd4c641c528b09aaf97c285 --- /dev/null +++ b/colossalai/_analyzer/_subclasses/flop_tensor.py @@ -0,0 +1,563 @@ +# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py +# ideas from https://pastebin.com/AkvAyJBw +# and https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505 + +import operator +from collections import defaultdict +from contextlib import contextmanager +from enum import Enum, auto +from functools import partial, reduce +from numbers import Number +from typing import Any, Callable, List, Optional, Union + +import torch +from packaging import version +from torch.utils._pytree import tree_map + +from .meta_tensor import MetaTensor + +aten = torch.ops.aten + + +class Phase(Enum): + FWD = auto() + BWD = auto() + + +def normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +def _format_flops(flop): + K = 1e3 + M = 1e6 + B = 1e9 + T = 1e12 + if flop < K: + return f'{flop:.2f}' + elif flop < M: + return f'{flop / K:.2f}K' + elif flop < B: + return f'{flop / M:.2f}M' + elif flop < T: + return f'{flop / B:.2f}B' + else: + return f'{flop / T:.2f}T' + + +def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number: + """ + Count the number of floating point operations in a model. + Ideas from https://pastebin.com/AkvAyJBw. + Args: + module (torch.nn.Module): A PyTorch model. + *args: Input arguments to the model. + verbose (bool): If True, print the number of flops for each module. + **kwargs: Input keyword arguments to the model. + Returns: + Number: The total number of floating point operations (FWD + BWD). + """ + maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False) + or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_')) + + class DummyModule(torch.nn.Module): + + def __init__(self, func): + super().__init__() + self.func = func + self.__name__ = func.__name__ + + def forward(self, *args, **kwargs): + return self.func(*args, **kwargs) + + total_flop_count = {Phase.FWD: 0, Phase.BWD: 0} + flop_counts = defaultdict(lambda: defaultdict(int)) + parents = ['Global'] + module = module if isinstance(module, torch.nn.Module) else DummyModule(module) + + class FlopTensor(MetaTensor): + _tensor: torch.Tensor + + def __repr__(self): + name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor' + if self.grad_fn: + return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})" + return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + + # no_dispatch is only needed if you use enable_python_mode. + # It prevents infinite recursion. + rs = super().__torch_dispatch__(func, types, args, kwargs) + + outs = normalize_tuple(rs) + + if func in flop_mapping: + nonlocal flop_counts, total_flop_count + flop_count = flop_mapping[func](args, outs) + for par in parents: + flop_counts[par][func.__name__] += flop_count + total_flop_count[cur_phase] += flop_count + + def wrap(x): + if isinstance(x, MetaTensor): + x = FlopTensor(x) + return x + + rs = tree_map(wrap, rs) + + return rs + + def is_autogradable(x): + return isinstance(x, torch.Tensor) and x.is_floating_point() + + def create_backwards_push(name): + + class PushState(torch.autograd.Function): + + @staticmethod + def forward(ctx, *args): + args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args) + if len(args) == 1: + return args[0] + return args + + @staticmethod + def backward(ctx, *grad_outs): + nonlocal parents + parents.append(name) + return grad_outs + + return PushState.apply + + def create_backwards_pop(name): + + class PopState(torch.autograd.Function): + + @staticmethod + def forward(ctx, *args): + args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args) + if len(args) == 1: + return args[0] + return args + + @staticmethod + def backward(ctx, *grad_outs): + nonlocal parents + assert (parents[-1] == name) + parents.pop() + return grad_outs + + return PopState.apply + + def enter_module(name): + + def f(module, inputs): + nonlocal parents + parents.append(name) + inputs = normalize_tuple(inputs) + out = create_backwards_pop(name)(*inputs) + return out + + return f + + def exit_module(name): + + def f(module, inputs, outputs): + nonlocal parents + assert (parents[-1] == name) + parents.pop() + outputs = normalize_tuple(outputs) + return create_backwards_push(name)(*outputs) + + return f + + @contextmanager + def instrument_module(mod): + registered = [] + for name, module in dict(mod.named_children()).items(): + registered.append(module.register_forward_pre_hook(enter_module(name))) + registered.append(module.register_forward_hook(exit_module(name))) + yield + for handle in registered: + handle.remove() + + def display_flops(): + for mod in flop_counts.keys(): + print(f"Module: ", mod) + for k, v in flop_counts[mod].items(): + print('\t', k, _format_flops(v)) + print() + + def detach_variables(r): + if isinstance(r, torch.Tensor): + requires_grad = r.requires_grad + r = r.detach() + r.requires_grad = requires_grad + return r + + def wrap(r): + if isinstance(r, torch.Tensor): + data_ptr_fn = getattr(r, '_tensor', r).data_ptr + r = FlopTensor(detach_variables(r)) + if maybe_inplace: + r = r + 0 + r._tensor.data_ptr = data_ptr_fn + return r + + with instrument_module(module): + cur_phase = Phase.FWD + rst = module(*tree_map(wrap, args), **tree_map(wrap, kwargs)) + rst = tuple(r for r in normalize_tuple(rst) if is_autogradable(r) and r.requires_grad) + cur_phase = Phase.BWD + + if rst: + grad = [torch.zeros_like(t) for t in rst] + torch.autograd.backward( + rst, + grad, + ) + + if verbose: + display_flops() + + return total_flop_count[Phase.FWD], total_flop_count[Phase.BWD] + + +def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for matmul. + """ + # Inputs should be a list of length 2. + # Inputs contains the shapes of two matrices. + input_shapes = [v.shape for v in inputs] + assert len(input_shapes) == 2, input_shapes + + # There are three cases: 1) gemm, 2) gemv, 3) dot + if all(len(shape) == 2 for shape in input_shapes): + # gemm + assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes + elif all(len(shape) == 1 for shape in input_shapes): + # dot + assert input_shapes[0][0] == input_shapes[1][0], input_shapes + + # expand shape + input_shapes[0] = torch.Size([1, input_shapes[0][0]]) + input_shapes[1] = torch.Size([input_shapes[1][0], 1]) + else: + # gemv + if len(input_shapes[0]) == 1: + assert input_shapes[0][0] == input_shapes[1][-2], input_shapes + input_shapes.reverse() + else: + assert input_shapes[1][0] == input_shapes[0][-1], input_shapes + + # expand the shape of the vector to [batch size, 1] + input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1]) + flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1] + return flops + + +def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for fully connected layers. + """ + # Count flop for nn.Linear + # inputs is a list of length 3. + input_shapes = [v.shape for v in inputs[1:3]] + # input_shapes[0]: [batch size, input feature dimension] + # input_shapes[1]: [input feature dimension, output feature dimension] + assert len(input_shapes[0]) == 2, input_shapes[0] + assert len(input_shapes[1]) == 2, input_shapes[1] + batch_size, input_dim = input_shapes[0] + output_dim = input_shapes[1][1] + flops = batch_size * input_dim * output_dim + return flops + + +def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for the aten::linear operator. + """ + # Inputs is a list of length 3; unlike aten::addmm, it is the first + # two elements that are relevant. + input_shapes = [v.shape for v in inputs[0:2]] + # input_shapes[0]: [dim0, dim1, ..., input_feature_dim] + # input_shapes[1]: [output_feature_dim, input_feature_dim] + assert input_shapes[0][-1] == input_shapes[1][-1] + flops = reduce(operator.mul, input_shapes[0]) * input_shapes[1][0] + return flops + + +def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for the bmm operation. + """ + # Inputs should be a list of length 2. + # Inputs contains the shapes of two tensor. + assert len(inputs) == 2, len(inputs) + input_shapes = [v.shape for v in inputs] + n, c, t = input_shapes[0] + d = input_shapes[-1][-1] + flops = n * c * t * d + return flops + + +def conv_flop_count( + x_shape: List[int], + w_shape: List[int], + out_shape: List[int], + transposed: bool = False, +) -> Number: + """ + Count flops for convolution. Note only multiplication is + counted. Computation for addition and bias is ignored. + Flops for a transposed convolution are calculated as + flops = (x_shape[2:] * prod(w_shape) * batch_size). + Args: + x_shape (list(int)): The input shape before convolution. + w_shape (list(int)): The filter shape. + out_shape (list(int)): The output shape after convolution. + transposed (bool): is the convolution transposed + Returns: + int: the number of flops + """ + batch_size = x_shape[0] + conv_shape = (x_shape if transposed else out_shape)[2:] + flops = batch_size * reduce(operator.mul, w_shape) * reduce(operator.mul, conv_shape) + return flops + + +def conv_flop_jit(inputs: List[Any], outputs: List[Any]): + """ + Count flops for convolution. + """ + x, w = inputs[:2] + x_shape, w_shape, out_shape = (x.shape, w.shape, outputs[0].shape) + transposed = inputs[6] + + return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed) + + +def transpose_shape(shape): + return [shape[1], shape[0]] + list(shape[2:]) + + +def conv_backward_flop_jit(inputs: List[Any], outputs: List[Any]): + grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]] + output_mask = inputs[-1] + fwd_transposed = inputs[7] + flop_count = 0 + + if output_mask[0]: + grad_input_shape = outputs[0].shape + flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed) + if output_mask[1]: + grad_weight_shape = outputs[1].shape + flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed) + + return flop_count + + +def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable: + """ + Args: + affine_arg_index: index of the affine argument in inputs + """ + + def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for norm layers. + """ + # Inputs[0] contains the shape of the input. + input_shape = inputs[input_arg_index].shape + + has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index], + 'shape') else inputs[affine_arg_index] + assert 2 <= len(input_shape) <= 5, input_shape + # 5 is just a rough estimate + flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4) + return flop + + return norm_flop_jit + + +def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number: + if training is None: + training = inputs[-3] + assert isinstance(training, bool), "Signature of aten::batch_norm has changed!" + if training: + return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore + has_affine = inputs[1].shape is not None + input_shape = reduce(operator.mul, inputs[0].shape) + return input_shape * (2 if has_affine else 1) + + +def ewise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Callable: + """ + Count flops by + input_tensor.numel() * input_scale + output_tensor.numel() * output_scale + Args: + input_scale: scale of the input tensor (first argument) + output_scale: scale of the output tensor (first element in outputs) + """ + + def ewise_flop(inputs: List[Any], outputs: List[Any]) -> Number: + ret = 0 + if input_scale != 0: + shape = inputs[0].shape + ret += input_scale * reduce(operator.mul, shape) if shape else 0 + if output_scale != 0: + shape = outputs[0].shape + ret += output_scale * reduce(operator.mul, shape) if shape else 0 + return ret + + return ewise_flop + + +def zero_flop_jit(*args): + """ + Count flops for zero flop layers. + """ + return 0 + + +if version.parse(torch.__version__) >= version.parse('1.12.0'): + flop_mapping = { + # gemm + aten.mm.default: matmul_flop_jit, + aten.matmul.default: matmul_flop_jit, + aten.addmm.default: addmm_flop_jit, + aten.bmm.default: bmm_flop_jit, + + # convolution + aten.convolution.default: conv_flop_jit, + aten._convolution.default: conv_flop_jit, + aten.convolution_backward.default: conv_backward_flop_jit, + + # normalization + aten.native_batch_norm.default: batchnorm_flop_jit, + aten.native_batch_norm_backward.default: batchnorm_flop_jit, + aten.cudnn_batch_norm.default: batchnorm_flop_jit, + aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True), + aten.native_layer_norm.default: norm_flop_counter(2, 0), + aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), + + # pooling + aten.avg_pool1d.default: ewise_flop_counter(1, 0), + aten.avg_pool2d.default: ewise_flop_counter(1, 0), + aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1), + aten.avg_pool3d.default: ewise_flop_counter(1, 0), + aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1), + aten.max_pool1d.default: ewise_flop_counter(1, 0), + aten.max_pool2d.default: ewise_flop_counter(1, 0), + aten.max_pool3d.default: ewise_flop_counter(1, 0), + aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0), + aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0), + aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1), + aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0), + aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1), + aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0), + aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1), + aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0), + aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1), + aten.embedding_dense_backward.default: ewise_flop_counter(0, 1), + aten.embedding.default: ewise_flop_counter(1, 0), + } + + ewise_flop_aten = [ + # basic op + aten.add.Tensor, + aten.add_.Tensor, + aten.div.Tensor, + aten.div_.Tensor, + aten.div.Scalar, + aten.div_.Scalar, + aten.mul.Tensor, + aten.mul.Scalar, + aten.mul_.Tensor, + aten.neg.default, + aten.pow.Tensor_Scalar, + aten.rsub.Scalar, + aten.sum.default, + aten.sum.dim_IntList, + aten.mean.dim, + + # activation op + aten.hardswish.default, + aten.hardswish_.default, + aten.hardswish_backward.default, + aten.hardtanh.default, + aten.hardtanh_.default, + aten.hardtanh_backward.default, + aten.hardsigmoid_backward.default, + aten.hardsigmoid.default, + aten.gelu.default, + aten.gelu_backward.default, + aten.silu.default, + aten.silu_.default, + aten.silu_backward.default, + aten.sigmoid.default, + aten.sigmoid_backward.default, + aten._softmax.default, + aten._softmax_backward_data.default, + aten.relu_.default, + aten.relu.default, + aten.tanh.default, + aten.tanh_backward.default, + aten.threshold_backward.default, + + # dropout + aten.native_dropout.default, + aten.native_dropout_backward.default, + + # distribution + aten.bernoulli_.float, + + # where + aten.where.self, + ] + for op in ewise_flop_aten: + flop_mapping[op] = ewise_flop_counter(1, 0) + + # fix-me: this will be removed in future + zero_flop_aten = [ + aten.as_strided.default, + aten.as_strided_.default, + aten.cat.default, + aten.clone.default, + aten.copy_.default, + aten.detach.default, + aten.expand.default, + aten.empty_like.default, + aten.new_empty.default, + aten.new_empty_strided.default, + aten.ones_like.default, + aten._reshape_alias.default, + aten.select.int, + aten.select_backward.default, + aten.squeeze.dim, + aten.slice.Tensor, + aten.slice_backward.default, + aten.split.Tensor, + aten.permute.default, + aten.t.default, + aten.transpose.int, + aten._to_copy.default, + aten.unsqueeze.default, + aten.unbind.int, + aten._unsafe_view.default, + aten.view.default, + aten.zero_.default, + aten.zeros_like.default, + ] + + for op in zero_flop_aten: + flop_mapping[op] = zero_flop_jit +else: + flop_mapping = {} + elementwise_flop_aten = {} + zero_flop_aten = {} diff --git a/colossalai/_analyzer/_subclasses/meta_tensor.py b/colossalai/_analyzer/_subclasses/meta_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc212938ee08f1143b5dda354d9ec600dd64662 --- /dev/null +++ b/colossalai/_analyzer/_subclasses/meta_tensor.py @@ -0,0 +1,207 @@ +import uuid +from functools import partial + +import torch +import torch.distributed as dist +from torch.types import _bool, _device, _dtype +from torch.utils._pytree import tree_flatten, tree_map + +from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod + +__all__ = ['MetaTensor', 'MetaTensorMode'] + + +def register_storage(r, data_ptr_fn=None): + if isinstance(r, torch.Tensor): + if data_ptr_fn is not None: + r.data_ptr = data_ptr_fn + elif not r.data_ptr(): + data_ptr = uuid.uuid1() + r.data_ptr = lambda: data_ptr + + +def _normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +# a hack of inplace execution in PyTorch +def _assert_alias(func): + return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen # TODO: check if should be this aggressive + ) + + +class MetaTensor(torch.Tensor): + """ + A wrapping tensor that hacks ``torch.autograd`` without patching more ``torch.ops.aten`` ops. + `device` is the device that ``MetaTensor`` is supposed to run on. Meta tensors give you the + ability to run PyTorch code without having to actually do computation through tensors + allocated on a `meta` device. Because the device is `meta`, meta tensors do not model + device propagation. ``MetaTensor`` extends its usage by carrying an additional `device` + which tracks devices that would have been used. + + Reference: + https://github.com/pytorch/pytorch/blob/master/torch/_subclasses/fake_tensor.py + """ + + _tensor: torch.Tensor + + @staticmethod + def __new__(cls, elem, device=None, data_ptr_fn=None): + requires_grad = elem.requires_grad + # Avoid multiple wrapping + while isinstance(elem, MetaTensor): + device = elem.device if device is None else device + elem = elem._tensor + + # The wrapping tensor (MetaTensor) shouldn't hold any + # memory for the class in question, but it should still + # advertise the same device as before + r = torch.Tensor._make_wrapper_subclass( + cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + dtype=elem.dtype, + layout=elem.layout, + device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')), + requires_grad=requires_grad) # deceive the frontend for aten selections + r._tensor = elem + # ...the real tensor is held as an element on the tensor. + if not r._tensor.is_meta: + val = elem.data_ptr() + data_ptr_fn = lambda: val + r._tensor = r._tensor.to(torch.device('meta')) + + # only tensor not on `meta` should be copied to `meta` + register_storage(r._tensor, data_ptr_fn) + if isinstance(elem, torch.nn.Parameter): + r = torch.nn.Parameter(r) + return r + + def __repr__(self): + name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor' + if self.grad_fn: + return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})" + return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + device = None + + def unwrap(x): + nonlocal device + if isinstance(x, MetaTensor): + device = x.device + x = x._tensor + elif isinstance(x, torch.Tensor): + device = x.device + x = x.to(torch.device('meta')) + return x + + args = tree_map(unwrap, args) + kwargs = tree_map(unwrap, kwargs) + + if 'device' in kwargs: + device = kwargs['device'] + kwargs['device'] = torch.device('meta') + + # run aten for backend=CPU but actually on backend=Meta + # here we detect whether or not the execution generates a physical copy + # of the input tensor + ret = func(*args, **kwargs) + + if _assert_alias(func): + val = args[0].data_ptr() + tree_map(partial(register_storage, data_ptr_fn=lambda: val), _normalize_tuple(ret)) + + # Now, we want to continue propagating this tensor, so we rewrap Tensors in + # our custom tensor subclass + def wrap(x): + return MetaTensor(x, device=device) if isinstance(x, torch.Tensor) else x + + return tree_map(wrap, ret) + + def to(self, *args, **kwargs) -> torch.Tensor: + """An extension of `torch.Tensor.to()` to MetaTensor + Returns: + result (MetaTensor): MetaTensor + Usage: + >>> tensor = MetaTensor(torch.rand(10), device='cuda:100') + >>> tensor.to(torch.uint8) + MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), device='cuda:100') + >>> tensor.to(torch.device('cuda:42')) + MetaTensor(tensor(..., device='meta', size=(10,)), device='cuda:42') + >>> tensor.to('vulkan') + MetaTensor(tensor(..., device='meta', size=(10,)), device='vulkan') + """ + # this imitates c++ function in the way of @overload + device = None + + def replace(x): + nonlocal device + if isinstance(x, str) or isinstance(x, _device): + device = x + return torch.device('meta') + return x + + elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs)) + return MetaTensor(elem, device=device) + + def cpu(self, *args, **kwargs): + if self.device.type == 'cpu': + return self.to(*args, **kwargs) + return self.to(*args, device='cpu', **kwargs) + + def cuda(self, device=None, non_blocking=False): + if device is not None: + return self.to(device=device, non_blocking=non_blocking) + return self.to(device='cuda:0', non_blocking=non_blocking) + + def data_ptr(self): + return self._tensor.data_ptr() + + +class MetaTensorMode(object): + """ + A context manager that enables MetaTensor mode. + + Usage: + >>> with MetaTensorMode(): + >>> # all torch.xxx and torch.distributed.xxx will be replaced by patched functions + >>> # and the actual execution will be on torch.device('meta') + >>> a = torch.rand(100000, 100000) + >>> b = torch.rand(100000, 100000) + >>> c = torch.mm(a, b) + """ + + def __init__(self): + self.torch_overrides = {} # override torch.xxx + self.dist_overrides = {} # override torch.distributed.xxx + + def __enter__(self): + + def _dummy(*args, **kwargs): + pass + + def _new(*args, orig_new=torch.empty, **kwargs): + return MetaTensor(orig_new(*args, **{ + **kwargs, 'device': 'meta' + }), + device=kwargs.get('device', torch.device('cpu'))) + + for func in _TorchOverrideableFactoryMethod: + self.torch_overrides[func] = getattr(torch, func) + setattr(torch, func, partial(_new, orig_new=getattr(torch, func))) + + for func in _DistCommMethod: + self.dist_overrides[func] = getattr(dist, func) + setattr(dist, func, _dummy) + + def __exit__(self, exc_type, exc_value, traceback): + for func, func_impl in self.torch_overrides.items(): + setattr(torch, func, func_impl) + + for func, func_impl in self.dist_overrides.items(): + setattr(dist, func, func_impl) diff --git a/colossalai/_analyzer/envs.py b/colossalai/_analyzer/envs.py new file mode 100644 index 0000000000000000000000000000000000000000..b537747c57a89824854b1b32f729fdb76b53dbd5 --- /dev/null +++ b/colossalai/_analyzer/envs.py @@ -0,0 +1,7 @@ +from dataclasses import dataclass + + +@dataclass +class MeshConfig: + TFLOPS: float = 1.9e12 + BANDWIDTH = 1.2e9 diff --git a/colossalai/_analyzer/fx/__init__.py b/colossalai/_analyzer/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa01de0bbe6c8e990e6b8054f525985343afa050 --- /dev/null +++ b/colossalai/_analyzer/fx/__init__.py @@ -0,0 +1,3 @@ +from .node_util import MetaInfo +from .symbolic_profile import symbolic_profile +from .tracer.symbolic_trace import symbolic_trace diff --git a/colossalai/_analyzer/fx/codegen.py b/colossalai/_analyzer/fx/codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..41d74f2e3719a0e56adc61a6c35684584a2d80c7 --- /dev/null +++ b/colossalai/_analyzer/fx/codegen.py @@ -0,0 +1,459 @@ +from typing import Any, Callable, Dict, Iterable, List, Tuple + +import torch + +try: + from torch.fx.graph import CodeGen +except: + pass +from torch.fx.graph import ( + PythonCode, + _custom_builtins, + _format_target, + _is_from_torch, + _Namespace, + _origin_type_map, + _register_custom_builtin, + inplace_methods, + magic_methods, +) +from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + +import colossalai +from colossalai.fx._compatibility import compatibility + +_register_custom_builtin('colossalai', 'import colossalai', colossalai) + + +def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str: + """ + Generate the checkpoint function definition + """ + return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):" + + +def _gen_ckpt_output(output_vars: List[str]) -> str: + """ + Generate the return statement for checkpoint region + """ + return f"return {', '.join(output_vars)}" + + +def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True): + """ + Generate the checkpoint function call code text + """ + outputs = ', '.join(output_vars) + inputs = ', '.join(input_vars) + return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})' + + +def _end_of_ckpt(node: Node, ckpt_level: int) -> bool: + """ + Check if the node could end the ckpt region at `ckpt_level` + """ + if len(node.meta['info'].activation_checkpoint) > ckpt_level: + return node.meta['info'].activation_checkpoint[ckpt_level] is not None + return True + + +def _find_input_and_output_nodes(nodes: List[Node]): + """ + Find the input and output node names which are not found in the given list of nodes. + """ + input_nodes = [] + output_nodes = [] + + # if a node has an input node which is not in the node list + # we treat that input node as the input of the checkpoint function + for node in nodes: + for input_node in node._input_nodes.keys(): + node_repr = repr(input_node) + if input_node not in nodes and node_repr not in input_nodes: + input_nodes.append(node_repr) + + # if a node has a user node which is not in the node list + # we treat that user node as the node receiving the current node output + for node in nodes: + for output_node in node.users.keys(): + node_repr = repr(node) + if output_node not in nodes and node_repr not in output_nodes: + output_nodes.append(node_repr) + + return input_nodes, output_nodes + + +def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0): + """ + Find the nested checkpoint regions given a list of consecutive nodes. The outputs + will be list of tuples, each tuple is in the form of (start_index, end_index). + """ + ckpt_regions = [] + start = -1 + end = -1 + current_region = None + + for idx, node in enumerate(node_list): + if len(node.meta['info'].activation_checkpoint) > ckpt_level: + act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level] + + # this activation checkpoint label is not set yet + # meaning this is the first node of the activation ckpt region + if current_region is None: + current_region = act_ckpt_label + start = idx + + # if activation checkpoint has changed + # we restart the tracking + # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2] + if act_ckpt_label != current_region: + assert start != -1 + ckpt_regions.append((start, idx - 1)) + current_region = act_ckpt_label + start = idx + end = -1 + + elif current_region is not None and _end_of_ckpt(node, ckpt_level): + # used to check the case below + # node ckpt states = [ckpt, ckpt, non-ckpt] + end = idx - 1 + assert start != -1 and end != -1 + ckpt_regions.append((start, end)) + start = end = -1 + current_region = None + + else: + pass + + if current_region is not None: + end = len(node_list) - 1 + ckpt_regions.append((start, end)) + return ckpt_regions + + +def emit_ckpt_func(body, + ckpt_func, + node_list: List[Node], + emit_node_func, + delete_unused_value_func, + ckpt_level=0, + in_ckpt=False): + """Emit ckpt function in nested way + + Args: + body: forward code - in recursive calls, this part will be checkpoint + functions code + ckpt_func: checkpoint functions code - in recursive calls, this part + will be a buffer + node_list (List[Node]): list of torch.fx.Node + emit_node_func: function to emit a node + delete_unused_value_func: function to delete unused value + level (int, optional): checkpoint level. Defaults to 0. + in_ckpt (bool, optional): indicates wether the func is in recursive + call. Defaults to False. + """ + inputs, outputs = _find_input_and_output_nodes(node_list) + + # label given by each layer, e.g. if you are currently at level (0, 1, 1) + # the label will be '0_1_1' + label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]]) + ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) + ckpt_func.append(f'{ckpt_fn_def}\n') + + # if there is more level to fetch + if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)): + ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1) + start_idx = [item[0] for item in ckpt_regions] + end_idx = [item[1] for item in ckpt_regions] + + # use ckpt_func_buffer to store nested checkpoint functions + ckpt_func_buffer = [] + node_idx = 0 + while 1: + if node_idx >= len(node_list): + break + + if node_idx in start_idx: + ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] + emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func, + ckpt_level + 1, True) + node_idx += len(ckpt_node_list) + + else: + node = node_list[node_idx] + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + node_idx += 1 + + ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + ckpt_func += ckpt_func_buffer + + # last level + else: + for node in node_list: + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + + ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + + usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n' + if in_ckpt: + usage = ' ' + usage + body.append(usage) + + +def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): + """Emit code with nested activation checkpoint + When we detect some of the annotation is a , we will use + this function to emit the activation checkpoint codes. + + Args: + body: forward code + ckpt_func: checkpoint functions code + nodes: graph.nodes + emit_node_func: function to emit node + delete_unused_value_func: function to remove the unused value + """ + ckpt_regions = _find_nested_ckpt_regions(nodes, 0) + start_idx = [item[0] for item in ckpt_regions] + end_idx = [item[1] for item in ckpt_regions] + node_list = list(nodes) + + node_idx = 0 + while 1: + # break if we finish the processing all the nodes + if node_idx >= len(node_list): + break + + # process ckpt_regions + if node_idx in start_idx: + ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] + emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func) + node_idx += len(ckpt_node_list) + + # process node in forward function + else: + node = node_list[node_idx] + emit_node_func(node, body) + delete_unused_value_func(node, body) + node_idx += 1 + + +@compatibility(is_backward_compatible=True) +class ActivationCheckpointCodeGen(CodeGen): + + def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: + free_vars: List[str] = [] + body: List[str] = [] + globals_: Dict[str, Any] = {} + wrapped_fns: Dict[str, None] = {} + + # Wrap string in list to pass by reference + maybe_return_annotation: List[str] = [''] + + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. + We call this for names that reference objects external to the + Graph, like functions or types. + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] is obj + return global_name + globals_[global_name] = obj + return global_name + + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) + + def type_repr(o: Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return '()' + + typename = _type_repr(o) + + if hasattr(o, '__origin__'): + # This is a generic type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + origin_typename = add_global(_type_repr(origin_type), origin_type) + + if hasattr(o, '__args__'): + # Assign global names for each of the inner type variables. + args = [type_repr(arg) for arg in o.__args__] + + if len(args) == 0: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python < 3.9 + return origin_typename + + return f'{origin_typename}[{",".join(args)}]' + else: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python 3.9+ + return origin_typename + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: + + def _get_repr(arg): + # Handle NamedTuples (if it has `_fields`) via add_global. + if isinstance(arg, tuple) and hasattr(arg, '_fields'): + qualified_name = _get_qualified_name(type(arg)) + global_name = add_global(qualified_name, type(arg)) + return f"{global_name}{repr(tuple(arg))}" + return repr(arg) + + args_s = ', '.join(_get_repr(a) for a in args) + kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + if args_s and kwargs_s: + return f'{args_s}, {kwargs_s}' + return args_s or kwargs_s + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + # NOTE: we add a variable to distinguish body and ckpt_func + def delete_unused_values(user: Node, body): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == 'placeholder': + return + if user.op == 'output': + body.append('\n') + return + nodes_to_delete = user_to_last_uses.get(user, []) + if len(nodes_to_delete): + to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) + body.append(f'; {to_delete_str}\n') + else: + body.append('\n') + + # NOTE: we add a variable to distinguish body and ckpt_func + def emit_node(node: Node, body): + maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' + if node.op == 'placeholder': + assert isinstance(node.target, str) + maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' + free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') + raw_name = node.target.replace('*', '') + if raw_name != repr(node): + body.append(f'{repr(node)} = {raw_name}\n') + return + elif node.op == 'call_method': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' + f'({_format_args(node.args[1:], node.kwargs)})') + return + elif node.op == 'call_function': + assert callable(node.target) + # pretty print operators + if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + assert isinstance(node.args, tuple) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + return + + # pretty print inplace operators; required for jit.script to work properly + # not currently supported in normal FX graphs, but generated by torchdynamo + if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods: + body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; ' + f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}') + return + + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if global_name == 'getattr' and \ + isinstance(node.args, tuple) and \ + isinstance(node.args[1], str) and \ + node.args[1].isidentifier() and \ + len(node.args) == 2: + body.append( + f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + return + body.append( + f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') + if node.meta.get('is_wrapped', False): + wrapped_fns.setdefault(global_name) + return + elif node.op == 'call_module': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + return + elif node.op == 'get_attr': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + return + elif node.op == 'output': + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append(self.generate_output(node.args[0])) + return + raise NotImplementedError(f'node: {node.op} {node.target}') + + # Modified for activation checkpointing + ckpt_func = [] + emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append('pass\n') + + if len(wrapped_fns) > 0: + wrap_name = add_global('wrap', torch.fx.wrap) + wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + else: + wrap_stmts = '' + + if self._body_transformer: + body = self._body_transformer(body) + + for name, value in self.additional_globals(): + add_global(name, value) + + prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) + prologue = ''.join(ckpt_func) + prologue + prologue = prologue + + code = ''.join(body) + code = '\n'.join(' ' + line for line in code.split('\n')) + fn_code = f""" +{wrap_stmts} +{prologue} +{code}""" + return PythonCode(fn_code, globals_) diff --git a/colossalai/_analyzer/fx/graph_module.py b/colossalai/_analyzer/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..1fdedd758c01f8fef9c0ba2e529b528527d8d527 --- /dev/null +++ b/colossalai/_analyzer/fx/graph_module.py @@ -0,0 +1,239 @@ +import linecache +import os +import sys +import traceback +import warnings +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import torch +import torch.fx +import torch.nn as nn +from torch.fx.graph import PythonCode + +try: + from torch.fx.graph import _PyTreeCodeGen + SUPPORT_PT_CODEGEN = True +except ImportError: + SUPPORT_PT_CODEGEN = False + +from torch.fx.graph_module import _exec_with_source, _forward_from_src +from torch.nn.modules.module import _addindent + + +# This is a copy of torch.fx.graph_module._WrappedCall. +# It should be removed when we stop supporting torch < 1.12.0. +class _WrappedCall: + + def __init__(self, cls, cls_call): + self.cls = cls + self.cls_call = cls_call + + # Previously, if an error occurred when valid + # symbolically-traced code was run with an invalid input, the + # user would see the source of the error as coming from + # `File "`, where N is some number. We use + # this function to generate a more informative error message. We + # return the traceback itself, a message explaining that the + # error occurred in a traced Module's generated forward + # function, and five lines of context surrounding the faulty + # line + @staticmethod + def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: + # auxiliary variables (for readability) + err_lineno = frame_summary.lineno + assert err_lineno is not None + line = frame_summary.line + assert line is not None + err_line_len = len(line) + all_src_lines = linecache.getlines(frame_summary.filename) + + # constituent substrings of the error message + tb_repr = traceback.format_exc() + custom_msg = ("Call using an FX-traced Module, " + f"line {err_lineno} of the traced Module's " + "generated forward function:") + before_err = "".join(all_src_lines[err_lineno - 2:err_lineno]) + marker = "~" * err_line_len + "~~~ <--- HERE" + err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2]) + + # joined message + return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) + + def __call__(self, obj, *args, **kwargs): + try: + if self.cls_call is not None: + return self.cls_call(obj, *args, **kwargs) + else: + return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] + except Exception as e: + assert e.__traceback__ + topmost_framesummary: traceback.FrameSummary = \ + traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type] + if "eval_with_key" in topmost_framesummary.filename: + print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr) + raise e.with_traceback(None) + else: + raise e + + +class ColoGraphModule(torch.fx.GraphModule): + """ + ColoGraphGraphModule is an nn.Module generated from an fx.Graph. + ColoGraphmodule has a ``graph`` attribute, as well as ``code`` and ``forward`` + attributes generated from that ``graph``. + + The difference between ``ColoGraphModule`` and ``torch.fx.GraphModule`` is that + ``ColoGraphModule`` has a ``bind()`` function to bind customized functions + (i.e. activation checkpoint) to ``code`` of ``nn.Module``. If you want to use + specific features in Colossal-AI that are not supported by ``torch.fx.GraphModule``, + you can use ``ColoGraphModule`` instead. + + ``colossalai.fx.symbolic_trace()`` will return a ``ColoGraphModule`` as default. + + .. warning:: + + When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically + regenerated. However, if you edit the contents of the ``graph`` without reassigning + the ``graph`` attribute itself, you must call ``recompile()`` to update the generated + code. + """ + + def __init__(self, + root: Union[torch.nn.Module, Dict[str, Any]], + graph: torch.fx.Graph, + class_name: str = 'GraphModule'): + super().__init__(root, graph, class_name) + + def bind(self, ckpt_def, globals): + """Bind function needed for correctly execute ``GraphModule.forward()`` + + We need to bind checkpoint functions to ``ColoGraphModule`` so that we could + correctly execute ``GraphModule.forward()`` + + Args: + ckpt_def (List[str]): definition before the forward function + globals (Dict[str, Any]): global variables + """ + + ckpt_code = "\n".join(ckpt_def) + globals_copy = globals.copy() + _exec_with_source(ckpt_code, globals_copy) + func_list = [func for func in globals_copy.keys() if "checkpoint" in func or "pack" in func] + for func in func_list: + tmp_func = globals_copy[func] + setattr(self, func, tmp_func.__get__(self, self.__class__)) + del globals_copy[func] + + def recompile(self) -> PythonCode: + """ + Recompile this GraphModule from its ``graph`` attribute. This should be + called after editing the contained ``graph``, otherwise the generated + code of this ``GraphModule`` will be out of date. + """ + if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen): + self._in_spec = self._graph._codegen.pytree_info.in_spec + self._out_spec = self._graph._codegen.pytree_info.out_spec + python_code = self._graph.python_code(root_module='self') + self._code = python_code.src + + # To split ckpt functions code and forward code + _code_list = self._code.split("\n") + _fwd_def = [item for item in _code_list if "def forward" in item][0] + _fwd_idx = _code_list.index(_fwd_def) + ckpt_def = _code_list[:_fwd_idx] + self._code = "\n".join(_code_list[_fwd_idx:]) + + self.bind(ckpt_def, python_code.globals) + + cls = type(self) + cls.forward = _forward_from_src(self._code, python_code.globals) + + # Determine whether this class explicitly defines a __call__ implementation + # to wrap. If it does, save it in order to have wrapped_call invoke it. + # If it does not, wrapped_call can use a dynamic call to super() instead. + # In most cases, super().__call__ should be torch.nn.Module.__call__. + # We do not want to hold a reference to Module.__call__ here; doing so will + # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. + cls_call = cls.__call__ if "__call__" in vars(cls) else None + + if '_wrapped_call' not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + + def call_wrapped(self, *args, **kwargs): + return self._wrapped_call(self, *args, **kwargs) + + cls.__call__ = call_wrapped + + # reset self._code to original src, otherwise to_folder will be wrong + self._code = python_code.src + return python_code + + def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"): + """Dumps out module to ``folder`` with ``module_name`` so that it can be + imported with ``from import `` + + Args: + + folder (Union[str, os.PathLike]): The folder to write the code out to + + module_name (str): Top-level name to use for the ``Module`` while + writing out the code + """ + folder = Path(folder) + Path(folder).mkdir(exist_ok=True) + torch.save(self.state_dict(), folder / 'state_dict.pt') + tab = " " * 4 + + # we add import colossalai here + model_str = f""" +import torch +from torch.nn import * +import colossalai + + +class {module_name}(torch.nn.Module): + def __init__(self): + super().__init__() +""" + + def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: + safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d] + if type(module) in safe_reprs: + return f"{module.__repr__()}" + else: + return None + + blobified_modules = [] + for module_name, module in self.named_children(): + module_str = _gen_model_repr(module_name, module) + if module_str is None: + module_file = folder / f'{module_name}.pt' + torch.save(module, module_file) + blobified_modules.append(module_name) + module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') + module_str = f"torch.load(r'{module_file}') # {module_repr}" + model_str += f"{tab*2}self.{module_name} = {module_str}\n" + + for buffer_name, buffer in self._buffers.items(): + if buffer is None: + continue + model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" + + for param_name, param in self._parameters.items(): + if param is None: + continue + model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" + + model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" + model_str += f"{_addindent(self.code, 4)}\n" + + module_file = folder / 'module.py' + module_file.write_text(model_str) + + init_file = folder / '__init__.py' + init_file.write_text('from .module import *') + + if len(blobified_modules) > 0: + warnings.warn("Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}") diff --git a/colossalai/_analyzer/fx/node_util.py b/colossalai/_analyzer/fx/node_util.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe8400a437eef944b6e8ac0a1126774797da031 --- /dev/null +++ b/colossalai/_analyzer/fx/node_util.py @@ -0,0 +1,211 @@ +from dataclasses import dataclass, field +from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union + +import torch +from torch.autograd.profiler_util import _format_memory, _format_time +from torch.fx import Graph, GraphModule, Node + +from colossalai._analyzer.envs import MeshConfig + + +def intersect(a, b): + return {k: a[k] for k in a if k in b} + + +def subtract(a, b): + return {k: a[k] for k in a if k not in b} + + +def union(a, b): + return {**a, **b} + + +def compute_size_in_bytes(elem: Union[torch.Tensor, Dict, List, Tuple, int]) -> int: + """Compute the size of a tensor or a collection of tensors in bytes. + + Args: + elem (torch.Tensor | Dict | List | Tuple | int): Arbitrary nested ``torch.Tensor`` data structure. + + Returns: + int: The size of the tensor or the collection of tensors in bytes. + """ + nbytes = 0 + if isinstance(elem, torch.Tensor): + if elem.is_quantized: + nbytes += elem.numel() * torch._empty_affine_quantized([], dtype=elem.dtype).element_size() + else: + nbytes += elem.numel() * torch.tensor([], dtype=elem.dtype).element_size() + elif isinstance(elem, dict): + value_list = [v for _, v in elem.items()] + nbytes += compute_size_in_bytes(value_list) + elif isinstance(elem, tuple) or isinstance(elem, list) or isinstance(elem, set): + for e in elem: + nbytes += compute_size_in_bytes(e) + return nbytes + + +@dataclass +class MetaInfo: + r""" + The base class to store all profiling and static graph analysis information + needed for auto-parallel system in Colossal-AI. + ============================================================================ + ------------------------------- + | FX.Node | <----- + [input/param] are ---> |[input/param] [grad_inp]| [grad_inp] contributes to the + placeholders (might be | | \__________ | | profiled peak memory in backward + saved for backward. | | \ | | pass. [grad_param] is calculated + | | \ | | separately. + | [interm] -------> [grad_int]| <----- + | | \_________ | | [grad_interm] marks the peak + | / \ \ | | memory in backward pass. + [x] is not counted ---> | [x] [interm] --> [grad_int]| <----- + in [interm] because | | \_____ | | + it is not saved for | | \ | | + backward. | [output] \ | | <----- [output] is potentially + ------------------------------- [input] for the next node. + ============================================================================ + + Accumulate Size = ALL_PREVIOUS_CTX U {Interm Size + Output Size} + Output Size = ([output] in global_ctx and not is_alias) + Temp Size = ([output] not in global_ctx and not is_alias) + Backward Size = ([grad_inp]) + + Usage: + >>> for node in graph.nodes: + >>> n_info = MetaInfo(node) # will create a new MetaInfo instance and store in node.meta['info'] + >>> # if not exist, otherwise return the existing one + >>> n_info.to_recompute = ... # set the to_recompute attribute + + Remarks: + This feature is experimental and all the entries are subject to change. + """ + + # reference + node: Node + + # directory + mod_dir: str = '' + + # ctx[data_ptr] = Tensor + # mark the storage for ctx.save_for_backward + global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared + curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node + + # should be updated after each graph manipulation + # ============================== Update ==================================== + # parameter and buffer within ``Node`` + parameters: Dict[str, torch.nn.Parameter] = field(default_factory=lambda: {}) + buffers: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) + + inputs: Tuple[torch.Tensor] = () + outputs: Tuple[torch.Tensor] = () + is_alias: Tuple[bool] = () # whether the output is an alias of input + + # compute cost + fwd_flop: Optional[int] = 0 + bwd_flop: Optional[int] = 0 + + # communication cost (should be the size in bytes of communication) + fwd_comm: Optional[int] = 0 + bwd_comm: Optional[int] = 0 + + # should keep the same whenever manipulated + # ============================= Invariant ================================== + activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen + to_offload: Optional[bool] = False + sharding_spec: str = 'RR' + + def __new__(cls, node: Node, **kwargs): + orig_init = cls.__init__ + + # if initialized, return the existing one + # should disable the __init__ function + if node.meta.get('info', None) is not None: + + def _dummy(self, *args, **kwargs): + if getattr(self, '_is_init', False): + self._is_init = True + orig_init(self, *args, **kwargs) + cls.__init__ = orig_init + + cls.__init__ = _dummy + return node.meta['info'] + return super().__new__(cls) + + def __post_init__(self): + self.node.meta['info'] = self + + @property + def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH): + return self.fwd_flop / tflops + self.fwd_comm / bandwidth + + @property + def bwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH): + return self.bwd_flop / tflops + self.bwd_comm / bandwidth + + @property + def param_size(self): + return compute_size_in_bytes(self.parameters) + + @property + def buffer_size(self): + return compute_size_in_bytes(self.buffers) + + @property + def output_size(self): + """Used in CheckpointSolver""" + output_ctx = { + o.data_ptr(): o + for o, is_alias in zip(self.outputs, self.is_alias) + if not is_alias and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter) + } + return compute_size_in_bytes(intersect(self.global_ctx, output_ctx)) + + @property + def accumulate_size(self): + """Used in CheckpointSolver""" + output_ctx = { + o.data_ptr(): o + for o, is_alias in zip(self.outputs, self.is_alias) + if not is_alias and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter) + } + return compute_size_in_bytes(union(self.curr_ctx, intersect(self.global_ctx, output_ctx))) + + @property + def temp_size(self): + """Used in CheckpointSolver""" + output_ctx = { + o.data_ptr(): o + for o, is_alias in zip(self.outputs, self.is_alias) + if not is_alias and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter) + } + return compute_size_in_bytes(subtract(output_ctx, self.global_ctx)) + + @property + def backward_size(self): + """Used in CheckpointSolver""" + return compute_size_in_bytes(self.inputs) + + def __repr__(self): + s = f'Node {self.node.name}' + if self.parameters: + s += f'\n\thas parameter of size {_format_memory(self.param_size)}' + if self.buffers: + s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}' + if self.output_size: + s += f'\n\thas output activation of size {_format_memory(self.output_size)}' + # if self.total_size: + # s += f'\n\thas total activation of size {_format_memory(self.total_size)}' + if self.temp_size: + s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}' + if self.backward_size: + s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}' + s += f'\n\tfwd_flop = {self.fwd_flop}'\ + f'\n\tbwd_flop = {self.bwd_flop}'\ + f'\n\tfwd_comm = {self.fwd_comm}'\ + f'\n\tbwd_comm = {self.bwd_comm}'\ + f'\n\tto_recompute = {self.to_recompute}'\ + f'\n\tto_offload = {self.to_offload}'\ + f'\n\tsharding_spec = {self.sharding_spec}' + return s diff --git a/colossalai/_analyzer/fx/passes/__init__.py b/colossalai/_analyzer/fx/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae02d90a236c08be74e70165d4c34c13e2884ab2 --- /dev/null +++ b/colossalai/_analyzer/fx/passes/__init__.py @@ -0,0 +1,2 @@ +from .graph_profile import graph_profile_pass +from .shape_prop import ShapeProp, shape_prop_pass, sim_env diff --git a/colossalai/_analyzer/fx/passes/graph_profile.py b/colossalai/_analyzer/fx/passes/graph_profile.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e760b31e96df2cd4f2de52c6309fefa9183417 --- /dev/null +++ b/colossalai/_analyzer/fx/passes/graph_profile.py @@ -0,0 +1,347 @@ +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +import torch +import torch.fx +from torch.autograd.profiler_util import _format_memory, _format_time +from torch.fx import GraphModule +from torch.fx.node import Argument, Node, Target + +from colossalai._analyzer._subclasses import flop_count +from colossalai._analyzer.fx.node_util import MetaInfo + + +def _format_flops(flops: float) -> str: + """Returns a formatted FLOP size string""" + if flops > 1e12: + return f'{flops / 1e12:.2f} TFLOPs' + elif flops > 1e9: + return f'{flops / 1e9:.2f} GFLOPs' + elif flops > 1e6: + return f'{flops / 1e6:.2f} MFLOPs' + elif flops > 1e3: + return f'{flops / 1e3:.2f} kFLOPs' + return f'{flops} FLOPs' + + +def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]: + return t[0] if len(t) == 1 else t + + +def _normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +def _current_device(module): + return next(module.parameters()).device + + +class GraphProfiler(torch.fx.Interpreter): + """ + Fetch shape argument from ``ShapeProp`` without re-executing + the ``GraphModule`` from scratch. + """ + _profileable = [ + 'call_function', + 'call_module', + 'call_method', + ] + + def __init__(self, module: GraphModule, garbage_collect_values: bool = True): + super().__init__(module, garbage_collect_values) + + def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any: + """ + Run `module` via interpretation and return the result. + + Args: + *args: The arguments to the Module to run, in positional order + initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution. + This is a dict mapping `Node` to any value. This can be used, for example, to + pre-populate results for certain `Nodes` so as to do only partial evaluation within + the interpreter. + enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and + process_outputs function first before using them. + + Returns: + Any: The value returned from executing the Module + """ + self.env = initial_env if initial_env else {} + + # Positional function args are consumed left-to-right by + # `placeholder` nodes. Use an iterator to keep track of + # position and extract those values. + if enable_io_processing: + args = self.module.graph.process_inputs(*args) + self.args_iter: Iterator[Any] = iter(args) + + for node in self.module.graph.nodes: + + self.run_node(node) # No need to store. + + if self.garbage_collect_values: + for to_delete in self.user_to_last_uses.get(node, []): + del self.env[to_delete] + + if node.op == 'output': + output_val = self.env[node] + return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val + + def fetch_initial_env(self, device=None) -> Dict[Node, Any]: + """ + Fetch ``initial_env`` for execution. This is because ``ShapeProp`` + has already attached outputs of each ``Node`` to its ``MetaInfo``. + + Args: + device (torch.device): The device to place the execution, default to ``None`` + + Returns: + Dict[Node, Any]: The initial environment for execution + """ + initial_env = {} + for n in self.module.graph.nodes: + initial_env[n] = _denormalize_tuple(MetaInfo(n).outputs) + return initial_env + + def propagate(self, *args, device=None): + """ + Run `module` via interpretation and profile the execution + of each ``Node``. + + Args: + *args (Tensor): The sample input, not used + device (torch.device): The device to place the execution, default to ``None`` + + Returns: + Any: The value returned from executing the Module + """ + initial_env = self.fetch_initial_env(device) + + return self.run(initial_env=initial_env) + + def summary(self) -> str: + """ + Summarizes the profiled statistics of the `GraphModule` in + tabular format. Note that this API requires the ``tabulate`` module + to be installed. + + Returns: + str: The summary of the profiled statistics + """ + # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py + try: + from tabulate import tabulate + except ImportError: + print("`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library.") + + # Build up a list of summary information for each node + node_summaries: List[List[Any]] = [] + last_n_info = None + + for node in self.module.graph.nodes: + node: Node + n_info = MetaInfo(node) + last_n_info = last_n_info or n_info + node_summaries.append([ + node.op, + str(node), + _format_memory(n_info.accumulate_size), + _format_memory(n_info.accumulate_size - last_n_info.accumulate_size), + _format_memory(n_info.output_size), + _format_memory(n_info.temp_size), + _format_memory(n_info.param_size), + _format_memory(n_info.backward_size), + _format_flops(n_info.fwd_flop), + _format_flops(n_info.bwd_flop), + ]) + last_n_info = n_info + + # Use the ``tabulate`` library to create a well-formatted table + # presenting our summary information + headers: List[str] = [ + 'Op type', + 'Op', + 'Accumulate size', + 'Incremental size', + 'Output size', + 'Temp size', + 'Param size', + 'Backward size', + 'Fwd FLOPs', + 'Bwd FLOPs', + ] + + return tabulate(node_summaries, headers=headers, stralign='right') + + +class CommunicationProfiler(GraphProfiler): + """ + TODO(lyl): Add this for all comm nodes + """ + + def __init__(self, module: GraphModule, garbage_collect_values: bool = True): + raise NotImplementedError() + + +class FlopProfiler(GraphProfiler): + """ + Execute an FX graph Node-by-Node and record the meta data of the result + into the corresponding node. + + Usage: + >>> model = MyModule() + >>> x = torch.rand(10, 10) + >>> gm = colossalai.fx.symbolic_trace(model, meta_args = {'x': x}}) + >>> shape_interp = ShapeProp(gm) # must do this first + >>> shape_interp.propagate(x) + >>> profiler = FlopProfiler(gm) + >>> profiler.propagate(x) + + Args: + module (GraphModule): The module to be executed + + Hints: + If you want to add a new flop count rule, you can first + check the existing files in ``../_subclasses/flop_tensor.py``. + If your flop count rules are incompatible with the existing + ones, you can do so by adding a new method to this class + with the ``@register_flop_count_impl`` decorator. The method + should take (*args, **kwargs) instance as its input and + generate flop count for both forward and backward as its + output. + + For example, if you want to add a flop count rule for + ``my_fn``, which is a hand-written operand not detected by + PyTorch, you can do so by adding a new method to this + class with the ``@register_flop_count_impl`` decorator: + + >>> @register_flop_count_impl(my_fn) + >>> def my_fn_flop_count_impl(*args, **kwargs): + >>> return 0, 0 + """ + _custom_flop_count_impl = {} + + def run_node(self, n: torch.fx.Node) -> Any: + """ + Run a specific node ``n`` and profile its execution time and memory usage. + Calls into call_function, call_method, and call_module only. + + Args: + n (Node): The Node to profile + + Returns: + Any: The output of the node + + Raises: + RuntimeError: If the node is not profileable. + """ + args, kwargs = self.fetch_args_kwargs_from_env(n) + n_info = MetaInfo(n) + + if n.op in self._profileable: + try: + ( + n_info.fwd_flop, + n_info.bwd_flop, + ) = getattr(self, n.op)(n.target, args, kwargs) + except Exception as e: + raise RuntimeError( + f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. ' + f'Please refer to function\'s docstring to register the relevant profile_impl for this node!' + ) from e + + # retain the autograd graph + for param in self.module.parameters(): + param.grad = None + + return _denormalize_tuple(n_info.outputs) + + def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_function`` node and return the profiling result. + Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be + profiled in a user-defined behavior. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + flop_count (Tuple[int]): (fwd_flop, bwd_flop) + """ + assert not isinstance(target, str) + + # Dispatch the impl for profiling, default will be ``flop_count`` + if target in self._custom_flop_count_impl: + return self._custom_flop_count_impl[target](*args, **kwargs) + else: + return flop_count(target, *args, **kwargs) + + def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_method`` node and return the profiling result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + flop_count (Tuple[int]): (fwd_flop, bwd_flop) + """ + # Execute the method and return the result + assert isinstance(target, str) + return flop_count(getattr(torch.Tensor, target), *args, **kwargs) + + def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_module`` node and return the profiling result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + flop_count (Tuple[int]): (fwd_flop, bwd_flop) + """ + # Retrieve executed args and kwargs values from the environment + + # Execute the method and return the result + assert isinstance(target, str) + submod = self.fetch_attr(target) + return flop_count(submod, *args, **kwargs) + + +def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule: + """ + Run ``module`` via interpretation and profile the execution + of each ``Node``. + + Args: + module (GraphModule): The GraphModule to profile + *args (Any): The sample input, not used + verbose (bool): Whether to print the profiling summary + + Returns: + GraphModule: The same GraphModule with profiling information + """ + for profiler_cls in (FlopProfiler, + # CommunicationProfiler, # TODO: add communication profiling + ): + profiler = profiler_cls(module) + profiler.propagate(*args, device=_current_device(module)) + + if verbose: + print(profiler.summary()) + return module diff --git a/colossalai/_analyzer/fx/passes/shape_prop.py b/colossalai/_analyzer/fx/passes/shape_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..23e83013e02fd6a60d016f9b5c86a17d85cdaf81 --- /dev/null +++ b/colossalai/_analyzer/fx/passes/shape_prop.py @@ -0,0 +1,266 @@ +"""``torch.fx.ShapeProp``, but with ``MetaTensor``""" + +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +import torch.fx +from torch.autograd.graph import saved_tensors_hooks +from torch.utils._pytree import tree_map + +from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode +from colossalai._analyzer.fx.node_util import MetaInfo +from colossalai.fx._compatibility import compatibility + +Target = Union[Callable[..., Any], str] + + +class sim_env(saved_tensors_hooks): + """ + A simulation of memory allocation and deallocation in the forward pass + using ``saved_tensor_hooks``. + + Attributes: + ctx (Dict[int, torch.Tensor]): A dictionary that maps the + data pointer of a tensor to the tensor itself. This is used + to track the memory allocation and deallocation. + + param_ctx (Dict[int, torch.Tensor]): A dictionary that maps the + data pointer of all model parameters to the parameter itself. + This avoids overestimating the memory usage of the intermediate activations. + """ + + def __init__(self, module: Optional[torch.nn.Module] = None): + super().__init__(self.pack_hook, self.unpack_hook) + self.ctx = {} + self.param_ctx = {param.data_ptr(): param for param in module.parameters()} + self.buffer_ctx = {buffer.data_ptr(): buffer for buffer in module.buffers()} if module else {} + + def pack_hook(self, tensor: torch.Tensor): + if tensor.data_ptr() not in self.param_ctx and tensor.data_ptr() not in self.buffer_ctx: + self.ctx[tensor.data_ptr()] = tensor + return tensor + + def unpack_hook(self, tensor): + return tensor + + +def _normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +def _current_device(module): + try: + return next(module.parameters()).device + except StopIteration: + return torch.device('cpu') + + +@compatibility(is_backward_compatible=False) +class ShapeProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node and record the meta data of the result + into the corresponding node. + + Usage: + >>> model = MyModule() + >>> x = torch.rand(10, 10) + >>> gm = colossalai.fx.symbolic_trace(model, meta_args = {'x': x}) + >>> interp = ShapeProp(gm) + >>> interp.propagate(x) + + Args: + module (GraphModule): The module to be executed + + Hints: + If you want to add a new shape propagation rule, you can do so by + adding a new method to this class with the ``@register_shape_impl`` + decorator. The method should take (*args, **kwargs) instance as its + input and generate output. + + For example, if you want to add a shape propagation rule for + ``torch.nn.functional.linear``, you can do so by adding a new method + to this class with the ``@register_shape_impl`` decorator (Since the + ``MetaTensorMode`` is compatible with ``torch.nn.functional.linear``, + in practice you don't have to do as follows): + + >>> @register_shape_impl(torch.nn.functional.linear) + >>> def linear_shape_impl(*args, **kwargs): + >>> # do something here + >>> return torch.empty(output_shape, device=output_device) + """ + _custom_dispatch_func = {} + _mode = MetaTensorMode() + + def __init__(self, module: torch.fx.GraphModule, garbage_collect_values: bool = True): + super().__init__(module, garbage_collect_values) + self.global_hook = sim_env(module=self.module) + + def run_node(self, n: torch.fx.Node) -> Any: + """ + Run a specific node ``n`` and return the result. Attach + ( + ``inputs``, ``outputs``, ``parameters``, ``buffers`` + ) to ``n``. + + Args: + n (Node): The ``Node`` to execute + + Returns: + Any: The result of executing ``n`` + """ + args, kwargs = self.fetch_args_kwargs_from_env(n) + with self.global_hook: + r = getattr(self, n.op)(n.target, args, kwargs) + + def unwrap_fn(elem): + + def _convert_meta(t: torch.Tensor): + if t.device == 'meta': + return t + else: + return t.to('meta') + + if isinstance(elem, MetaTensor): + if getattr(self, '_is_param', False): + return torch.nn.Parameter(_convert_meta(elem._tensor)) + return _convert_meta(elem._tensor) + + elif isinstance(elem, torch.Tensor): + if isinstance(elem, torch.nn.Parameter): + return torch.nn.Parameter(_convert_meta(elem)) + return _convert_meta(elem) + + else: + return elem + + is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter) + n_info = MetaInfo(n) + n_info.outputs = _normalize_tuple(r) + + if n.op == 'call_module': + submod = self.fetch_attr(n.target) + n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()}) + n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()}) + + else: + n_info.parameters.update({ + k.name: MetaTensor(v) + for k, v in zip(n.args, args) + if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter) + }) + n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)}) + + n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \ + tuple(v for v in kwargs.values() if is_pure_tensor(v)) + + # align with SPMD + if isinstance(r, (tuple, list)): + n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r)) + else: + n._meta_data = unwrap_fn(r) + + n_info.global_ctx = self.global_hook.ctx + n_info.curr_ctx = self.global_hook.ctx.copy() + + crit = lambda x: x.data_ptr() in self.global_hook.ctx if isinstance(x, torch.Tensor) else False + n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs)) + return r + + def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_function`` node and return the result. + If the target of ``Node`` is registered with ``@register_shape_impl``, + the registered function will be used to execute the node. This is common + if we insert some customized kernels. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the function invocation + """ + convert_to_param = False + if target in (torch.transpose, torch.reshape) and isinstance(args[0], torch.nn.parameter.Parameter): + convert_to_param = True + if target in self._custom_dispatch_func: + res = self._custom_dispatch_func[target](*args, **kwargs) + else: + res = super().call_function(target, args, kwargs) + if convert_to_param: + return torch.nn.Parameter(res) + else: + return res + + def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_method`` node and return the result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the method invocation + """ + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + + target_method = getattr(self_obj.__class__, target) + + convert_to_parameter = False + if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance( + args[0], torch.nn.parameter.Parameter): + convert_to_parameter = True + # Execute the method and return the result + assert isinstance(target, str) + res = getattr(self_obj, target)(*args_tail, **kwargs) + if convert_to_parameter: + return torch.nn.Parameter(res) + else: + return res + + def propagate(self, *args, device=None): + """ + Run `module` via interpretation and return the result and record the + shape of each node. + Args: + *args (Tensor): The sample input. + Returns: + Any: The value returned from executing the Module + """ + + # wrap_fn = lambda elem: MetaTensor(elem, device=device) + def wrap_fn(elem, device=device): + if isinstance(elem, torch.Tensor): + return MetaTensor(elem, device=device) + else: + return elem + + with self._mode: + return super().run(*tree_map(wrap_fn, args)) + + +def shape_prop_pass(module: torch.fx.GraphModule, *args) -> torch.fx.GraphModule: + """ + Run ``module`` via interpretation and return the result and record the + shape of each ``Node``. + + Args: + module (GraphModule): The GraphModule to profile + *args (Any): The sample input + + Returns: + GraphModule: The same GraphModule with shape information + """ + + ShapeProp(module).propagate(*args, device=_current_device(module)) + return module diff --git a/colossalai/_analyzer/fx/symbolic_profile.py b/colossalai/_analyzer/fx/symbolic_profile.py new file mode 100644 index 0000000000000000000000000000000000000000..dd7f22c6c98a0d47c946935a991a1f31d1052734 --- /dev/null +++ b/colossalai/_analyzer/fx/symbolic_profile.py @@ -0,0 +1,40 @@ +import torch +import torch.fx +from torch.fx import GraphModule + +from .passes import ShapeProp, graph_profile_pass, shape_prop_pass +from .passes.graph_profile import FlopProfiler + + +def register_flop_count_impl(func): + + def wrapper(impl): + FlopProfiler._custom_flop_count_impl[func] = impl + return impl + + return wrapper + + +def register_shape_impl(func): + + def wrapper(impl): + ShapeProp._custom_dispatch_func[func] = impl + return impl + + return wrapper + + +def symbolic_profile(module: GraphModule, *args, verbose=False) -> GraphModule: + """Symbolically profile a model with sample inputs. + + Args: + module (GraphModule): The module to be profiled + args (Tuple): The sample inputs + verbose (bool): Whether to print the profiling result + + Returns: + GraphModule: The profiled module + """ + module = shape_prop_pass(module, *args) + module = graph_profile_pass(module, *args, verbose=verbose) + return module diff --git a/colossalai/_analyzer/fx/tracer/__init__.py b/colossalai/_analyzer/fx/tracer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b1b2256aa44155883c71c10d1b2bebf511ff8d6 --- /dev/null +++ b/colossalai/_analyzer/fx/tracer/__init__.py @@ -0,0 +1,2 @@ +from .bias_addition import * +from .custom_leaf_module import * diff --git a/colossalai/_analyzer/fx/tracer/bias_addition.py b/colossalai/_analyzer/fx/tracer/bias_addition.py new file mode 100644 index 0000000000000000000000000000000000000000..1e75b47ca5b038aadb9c9bf0779bc3565d91bead --- /dev/null +++ b/colossalai/_analyzer/fx/tracer/bias_addition.py @@ -0,0 +1,154 @@ +""" +If FX.Graph is traced for auto-parallel module, some extra node will be added during +graph construction to deal with the compatibility between bias-addition and all-reduce. +""" + +import torch +import torch.nn.functional as F +from torch.nn.modules.utils import _pair, _single, _triple + +from .tracer import register_tracer_impl + +__all__ = [] + + +@register_tracer_impl(F.linear, name='_bias_addition_impl') +def linear_impl(input, weight, bias=None): + if bias is None: + return F.linear(input, weight) + else: + return F.linear(input, weight) + bias + + +@register_tracer_impl(F.conv1d, name='_bias_addition_impl') +def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1): + if bias is None: + return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + else: + return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( + (-1, 1)) + + +@register_tracer_impl(F.conv2d, name='_bias_addition_impl') +def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1): + if bias is None: + return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + else: + return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( + (-1, 1, 1)) + + +@register_tracer_impl(F.conv3d, name='_bias_addition_impl') +def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1): + if bias is None: + return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + else: + return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( + (-1, 1, 1, 1)) + + +@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl') +def conv_transpose1d_impl(input, + weight, + bias=None, + stride=_single(1), + padding=_single(0), + output_padding=_single(0), + groups=1, + dilation=_single(1)): + if bias is None: + return F.conv_transpose1d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) + else: + return F.conv_transpose1d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) + bias.reshape((-1, 1)) + + +@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl') +def conv_transpose2d_impl(input, + weight, + bias=None, + stride=_pair(1), + padding=_pair(0), + output_padding=_pair(0), + groups=1, + dilation=_pair(1)): + if bias is None: + return F.conv_transpose2d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) + else: + return F.conv_transpose2d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) + bias.reshape((-1, 1, 1)) + + +@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl') +def conv_transpose3d_impl(input, + weight, + bias=None, + stride=_triple(1), + padding=_triple(0), + output_padding=_triple(0), + groups=1, + dilation=_triple(1)): + if bias is None: + return F.conv_transpose3d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) + else: + return F.conv_transpose3d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) + bias.reshape((-1, 1, 1, 1)) + + +@register_tracer_impl(torch.addmm, name='_bias_addition_impl') +@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl') +def addmm_impl(input, mat1, mat2, beta=1, alpha=1): + if alpha != 1 and beta != 1: + return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta + elif alpha != 1: + return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input + elif beta != 1: + return F.linear(mat1, mat2.transpose(0, 1)) + input * beta + else: + return F.linear(mat1, mat2.transpose(0, 1)) + input + + +@register_tracer_impl(torch.addbmm, name='_bias_addition_impl') +@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl') +def addbmm_impl(input, batch1, batch2, beta=1, alpha=1): + if alpha != 1 and beta != 1: + return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta + elif alpha != 1: + return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input + elif beta != 1: + return torch.bmm(batch1, batch2.transpose(1, 2)) + input * beta + else: + return torch.bmm(batch1, batch2.transpose(1, 2)) + input diff --git a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py new file mode 100644 index 0000000000000000000000000000000000000000..112c7c9637d20e395dbccaace063e3fa7657041f --- /dev/null +++ b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py @@ -0,0 +1,29 @@ +import torch + +from .tracer import register_leaf_module, register_leaf_module_impl + +try: + import apex + register_leaf_module(apex.normalization.FusedLayerNorm) + register_leaf_module(apex.normalization.FusedRMSNorm) + register_leaf_module(apex.normalization.MixedFusedLayerNorm) + register_leaf_module(apex.normalization.MixedFusedRMSNorm) + + @register_leaf_module_impl(apex.normalization.FusedLayerNorm) + @register_leaf_module_impl(apex.normalization.FusedRMSNorm) + @register_leaf_module_impl(apex.normalization.MixedFusedLayerNorm) + @register_leaf_module_impl(apex.normalization.MixedFusedRMSNorm) + def torch_nn_normalize(self, input: torch.Tensor): + # check shape + if isinstance(self, torch.nn.BatchNorm1d): + assert input.dim() in [2, 3] + elif isinstance(self, torch.nn.BatchNorm2d): + assert input.dim() == 4 + elif isinstance(self, torch.nn.BatchNorm3d): + assert input.dim() == 5 + + # normalization maintain the same shape as the input + return input.clone() + +except (ImportError, AttributeError): + pass diff --git a/colossalai/_analyzer/fx/tracer/proxy.py b/colossalai/_analyzer/fx/tracer/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..ce379efdcf0d7c01d5541cb9ebffac1325fa97ef --- /dev/null +++ b/colossalai/_analyzer/fx/tracer/proxy.py @@ -0,0 +1,112 @@ +import operator +from typing import Any, Callable, Dict, Optional, Set, Union + +import torch +import torch.nn as nn +from torch.fx import Graph, Node, Proxy, Tracer +from torch.fx.graph import _Namespace +from torch.utils._pytree import tree_map + +from colossalai._analyzer._subclasses import MetaTensor + +Target = Union[Callable[..., Any], str] + + +class ColoProxy(Proxy): + _func_dispatch: Dict[Target, Callable[..., Any]] = {} + + def __init__(self, *args, data=None, **kwargs): + super().__init__(*args, **kwargs) + self._meta_data = data + + @property + def meta_data(self): + return self._meta_data + + @meta_data.setter + def meta_data(self, args): + wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x + self._meta_data = tree_map(wrap_fn, args) + + @classmethod + def __torch_function__(cls, orig_method, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + if orig_method in cls._func_dispatch: + impl = cls._func_dispatch.pop(orig_method) # avoid recursion + proxy = impl(*args, **kwargs) + cls._func_dispatch[orig_method] = impl + return proxy + else: + proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs)) + unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p + if proxy.meta_data is None: + proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + return proxy + + @classmethod + def from_torch_proxy(cls, proxy: Proxy): + return cls(proxy.node, proxy.tracer) + + def __repr__(self): + return f"ColoProxy({self.node.name}, meta_data={self.meta_data})" + + def __len__(self): + return len(self.meta_data) + + def __int__(self): + return int(self.meta_data) + + def __index__(self): + try: + return int(self.meta_data) + except: + return torch.zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__() + + def __float__(self): + return float(self.meta_data) + + def __bool__(self): + return self.meta_data + + def __getattr__(self, k): + return ColoAttribute(self, k, getattr(self._meta_data, k, None)) + + def __setitem__(self, key, value): + proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {}) + proxy.meta_data = self._meta_data + return proxy + + def __contains__(self, key): + if self.node.op == "placeholder": + # this is used to handle like + # if x in kwargs + # we don't handle this case for now + return False + return super().__contains__(key) + + def __isinstancecheck__(self, type): + return isinstance(self.meta_data, type) + + +class ColoAttribute(ColoProxy): + + def __init__(self, root, attr: str, data=None): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._meta_data = data + self._node: Optional[Node] = None + + @property + def node(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + + def __repr__(self): + return f"ColoAttribute({self.node.name}, attr={self.attr})" diff --git a/colossalai/_analyzer/fx/tracer/symbolic_trace.py b/colossalai/_analyzer/fx/tracer/symbolic_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..2018863f6f5f50de7cc61eafb907a55034711993 --- /dev/null +++ b/colossalai/_analyzer/fx/tracer/symbolic_trace.py @@ -0,0 +1,157 @@ +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union + +import torch +from torch.fx import Tracer +from torch.utils._pytree import tree_map + +from colossalai._analyzer._subclasses import MetaTensor + +try: + from ..codegen import ActivationCheckpointCodeGen + SUPPORT_ACTIVATION = True +except: + SUPPORT_ACTIVATION = False +from ..graph_module import ColoGraphModule +from .tracer import ColoTracer + + +def _default_device(): + return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + + +def _current_device(module: torch.nn.Module): + try: + return next(module.parameters()).device + except: + return _default_device() + + +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + meta_args: Optional[Dict[str, Any]] = None, + trace_act_ckpt: bool = False, + bias_addition_split: bool = False, +) -> ColoGraphModule: + """ + Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo`` + attached to the ``Node``s. + + Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of module + (https://github.com/pytorch/examples/blob/main/fx/module_tracer.py). + + This tracer is able to trace basic control flow and for loops. + + It will split the bias addition into two parts if ``bias_addition_split`` is set to be ``True``. + (See ./bias_addition.py for more details). + + Examples: + 1. Tracing a ``torch.nn.Module`` with control flow. + + .. code-block:: python + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + if x.size(0) > 1: + x = x.sum(dim=0) + return self.linear(x) + + traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}) + + # traced code like: + # def forward(self, x): + # linear_1 = self.linear(x) + # return linear_1 + + traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)}) + + # traced code like: + # def forward(self, x): + # sum = x.sum(dim=0); x = None + # linear = self.linear(sum); sum = None + # return linear + + 2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``. + + .. code-block:: python + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + def custom_forward(x): + return self.linear(x) + return torch.utils.checkpoint.checkpoint(custom_forward, x) + + traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True) + + # traced code like: + # def checkpoint_0(self, x): + # linear = self.linear(x); x = None + # return linear + # + # def forward(self, x): + # linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None + # return linear + + 3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``. + + .. code-block:: python + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2, bias=True) + + def forward(self, x): + return self.linear(x) + + traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True) + + # traced code like: + # def forward(self, x): + # linear_bias = self.linear.bias + # linear_weight = self.linear.weight + # linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None + # add = linear + linear_bias; linear = linear_bias = None + # return add + + Args: + root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced. + concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``. + Defaults to {}. + meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used + for tracing control flow. Defaults to {}. + trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``. + Defaults to False. + bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False. + + Returns: + ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``. + + Remarks: + This part of ``symbolic_trace()`` is maintained by Colossal-AI team. If you encountered + any unexpected error during tracing, feel free to raise an issue on Colossal-AI GitHub + repo. We welcome any feedback and contributions to enhance the extensibility of + Colossal-AI. + """ + if meta_args: + device, orig_device = _default_device(), _current_device(root) + wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem + graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, + bias_addition_split=bias_addition_split).trace(root.to(device), + concrete_args=concrete_args, + meta_args=tree_map(wrap_fn, meta_args)) + if trace_act_ckpt and SUPPORT_ACTIVATION: + graph.set_codegen(ActivationCheckpointCodeGen()) + root.to(orig_device) + else: + graph = Tracer().trace(root, concrete_args=concrete_args) + name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + return ColoGraphModule(root, graph, name) diff --git a/colossalai/_analyzer/fx/tracer/tracer.py b/colossalai/_analyzer/fx/tracer/tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..6958a00a6a72af16bf6a9736a7c18411ff127b76 --- /dev/null +++ b/colossalai/_analyzer/fx/tracer/tracer.py @@ -0,0 +1,363 @@ +import functools +import inspect +from contextlib import contextmanager +from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Type, Union + +import torch +import torch.nn as nn +from torch.fx import Graph, Node, Proxy, Tracer +from torch.utils._pytree import tree_map + +from colossalai._analyzer._subclasses import _TensorPropertyMethod, _TorchFactoryMethod + +from ..node_util import MetaInfo +from .proxy import ColoProxy + +Target = Union[Callable[..., Any], str] + + +def _truncate_suffix(s: str): + import re + + # FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name + return re.sub(r'_\d+$', '', s) + + +def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'): + + def wrapper(impl): + assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}" + getattr(ColoTracer, name)[func] = impl + return impl + + return wrapper + + +def register_leaf_module_impl(module: nn.Module): + + def wrapper(impl): + ColoTracer._custom_leaf_module_impl[module] = impl + return impl + + return wrapper + + +def register_leaf_module(module: nn.Module): + ColoTracer._custom_leaf_module.add(module) + + +def register_non_leaf_module(module: nn.Module): + ColoTracer._custom_non_leaf_module.add(module) + + +class ColoTracer(Tracer): + _custom_leaf_module: Set[Type[nn.Module]] = set() + _custom_leaf_module_impl: Dict[Type[nn.Module], Callable[..., Any]] = {} + _custom_non_leaf_module: Set[Type[nn.Module]] = set() + _custom_impl: Dict[Callable[..., Any], Callable[..., Any]] = {} + _bias_addition_impl: Dict[Callable[..., Any], Callable[..., Any]] = {} + _bias_addition_module = [ + torch.nn.Linear, + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, + ] + + def __init__(self, trace_act_ckpt: bool = False, bias_addition_split: bool = False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.disable_module_getattr = False + self.proxy_buffer_attributes = True + + # whether the tracer will record the usage of torch.utils.checkpoint + self.trace_act_ckpt = trace_act_ckpt + self.ckpt_regions = [] + self.ckpt_idx = 0 + + self.mod_dir = '' + + # whether the tracer should split the bias_add ops into two ops + self.bias_addition_split = bias_addition_split + + def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool: + # if bias-addiction split is enabled, and module has bias, then it is not a leaf module + # we will enter the module and split the bias-addition ops + if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None: + return False + # user can specify which modules are leaf modules and which are not + return (type(m) not in self._custom_non_leaf_module + and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name))) + + def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], + kwargs: Dict[str, Any]) -> Any: + curr_dir = self.mod_dir + self.mod_dir = 'self.' + self.path_of_module(m) + rst = super().call_module(m, forward, args, kwargs) + self.mod_dir = curr_dir + return rst + + def proxy(self, node: Node) -> 'ColoProxy': + return ColoProxy(node, self) + + def create_proxy(self, + kind: str, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Callable[[Node], 'Proxy'] = None): + + proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p + if kind == 'placeholder': + proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get( + _truncate_suffix(target), None) + elif kind == 'get_attr': + self.disable_module_getattr = True + try: + attr_itr = self.root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + proxy.meta_data = attr_itr + finally: + self.disable_module_getattr = False + elif kind == 'call_function': + proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + elif kind == 'call_method': + self.disable_module_getattr = True + try: + if target == '__call__': + proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) + else: + if target not in _TensorPropertyMethod: + proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), + **tree_map(unwrap_fn, kwargs)) + finally: + self.disable_module_getattr = False + elif kind == 'call_module': + mod = self.root.get_submodule(target) + self.disable_module_getattr = True + try: + args = tree_map(unwrap_fn, args) + kwargs = tree_map(unwrap_fn, kwargs) + if type(mod) in self._custom_leaf_module: + target = self._custom_leaf_module_impl[type(mod)] + proxy.meta_data = target(mod, *args, **kwargs) + else: + proxy.meta_data = mod.forward(*args, **kwargs) + finally: + self.disable_module_getattr = False + return proxy + + def create_node(self, *args, **kwargs) -> Node: + node = super().create_node(*args, **kwargs) + n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions)) + return node + + def trace(self, + root: torch.nn.Module, + concrete_args: Optional[Dict[str, torch.Tensor]] = None, + meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph: + + if meta_args is None: + meta_args = {} + + if concrete_args is None: + concrete_args = {} + + # check concrete and meta args have valid names + sig = inspect.signature(root.forward) + sig_names = set(sig.parameters.keys()) + meta_arg_names = set(meta_args.keys()) + concrete_arg_names = set(concrete_args.keys()) + non_concrete_arg_names = sig_names - concrete_arg_names + # update concrete args with default values + for k, v in sig.parameters.items(): + if k in sig_names - meta_arg_names and \ + k not in concrete_args and \ + v.default is not inspect.Parameter.empty: + concrete_args[k] = v.default + + def _check_arg_name_valid(names: Iterable[str]): + for name in names: + if name not in sig_names: + raise ValueError(f"Argument {name} is not in the signature of {root.__class__.__name__}.forward") + + _check_arg_name_valid(meta_arg_names) + _check_arg_name_valid(concrete_arg_names) + + self.concrete_args = concrete_args + self.meta_args = meta_args + + with self._torch_factory_override(), self._tracer_override(), torch.no_grad(): + self.mod_dir = 'self' + self.graph = super().trace(root, concrete_args=concrete_args) + self.mod_dir = '' + self.graph.lint() + + for node in self.graph.nodes: + if node.op == "placeholder": + # Removing default values for inputs as the forward pass will fail with them. + if node.target in non_concrete_arg_names: + node.args = () + # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor]. + # It cannot infer on the attributes and methods the input should have, and fails. + node.type = torch.Tensor + # It is a concrete arg so it is not used and should be removed. + else: + if hasattr(torch.fx._symbolic_trace, "_assert_is_none"): + # Newer versions of torch.fx emit an assert statement + # for concrete arguments; delete those before we delete + # the concrete arg. + to_delete = [] + for user in node.users: + if user.target == torch.fx._symbolic_trace._assert_is_none: + to_delete.append(user) + for user in to_delete: + self.graph.erase_node(user) + + self.graph.erase_node(node) + + # TODO: solves GraphModule creation. + # Without this, return type annotation "Tuple" is causing code execution failure. + if node.op == "output": + node.type = None + return self.graph + + @contextmanager + def _tracer_override(self): + # override the tracer to support custom modules and checkpointing + if self.trace_act_ckpt: + orig_ckpt_func_apply = torch.utils.checkpoint.CheckpointFunction.apply + orig_ckpt_func_without_reentrant = torch.utils.checkpoint._checkpoint_without_reentrant + + def checkpoint(run_function, preserve_rng_state=False, *args): + self.ckpt_regions.append(self.ckpt_idx) + out = run_function(*args) + self.ckpt_idx = self.ckpt_regions.pop(-1) + 1 + return out + + # override the checkpoint function + torch.utils.checkpoint.CheckpointFunction.apply = checkpoint + torch.utils.checkpoint._checkpoint_without_reentrant = checkpoint + + # override the custom functions + ColoProxy._func_dispatch.update({k: v for k, v in self._custom_impl.items()}) + + # override the bias addition functions + if self.bias_addition_split: + ColoProxy._func_dispatch.update({k: v for k, v in self._bias_addition_impl.items()}) + + yield + + if self.trace_act_ckpt: + # recover the checkpoint function upon exit + torch.utils.checkpoint.CheckpointFunction.apply = orig_ckpt_func_apply + torch.utils.checkpoint._checkpoint_reentrant = orig_ckpt_func_without_reentrant + + ColoProxy._func_dispatch = {} + + @contextmanager + def _torch_factory_override(self): + # override the torch factory functions to create a proxy when the method + # is called during ``symbolic_trace()``. + def wrap_factory_method(target): + + @functools.wraps(target) + def wrapper(*args, **kwargs): + is_proxy = any(isinstance(p, ColoProxy) for p in args) | any( + isinstance(p, ColoProxy) for p in kwargs.values()) + if is_proxy: + # if the arg is a proxy, then need to record this function called on this proxy + # e.g. torch.ones(size) where size is an input proxy + self.disable_module_getattr = True + try: + proxy = self.create_proxy('call_function', target, args, kwargs) + finally: + self.disable_module_getattr = False + return proxy + else: + return target(*args, **kwargs) + + return wrapper, target + + overrides = { + target: wrap_factory_method(getattr(torch, target)) + for target in _TorchFactoryMethod + if callable(getattr(torch, target)) + } + for name, (wrapper, orig) in overrides.items(): + setattr(torch, name, wrapper) + + yield + + # recover the torch factory functions upon exit + for name, (wrapper, orig) in overrides.items(): + setattr(torch, name, orig) + + def _post_check(self, non_concrete_arg_names: Set[str]): + # This is necessary because concrete args are added as input to the traced module since + # https://github.com/pytorch/pytorch/pull/55888. + for node in self.graph.nodes: + if node.op == "placeholder": + # Removing default values for inputs as the forward pass will fail with them. + if node.target in non_concrete_arg_names: + node.args = () + # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor]. + # It cannot infer on the attributes and methods the input should have, and fails. + node.type = torch.Tensor + # It is a concrete arg so it is not used and should be removed. + else: + if hasattr(torch.fx._symbolic_trace, "_assert_is_none"): + # Newer versions of torch.fx emit an assert statement + # for concrete arguments; delete those before we delete + # the concrete arg. + to_delete = [] + for user in node.users: + if user.target == torch.fx._symbolic_trace._assert_is_none: + to_delete.append(user) + for user in to_delete: + self.graph.erase_node(user) + + self.graph.erase_node(node) + + if node.op == "output": + node.type = None + self.graph.lint() + + def getattr(self, attr, attr_val, parameter_proxy_cache): + return self._module_getattr(attr, attr_val, parameter_proxy_cache) + + def _module_getattr(self, attr, attr_val, parameter_proxy_cache): + if getattr(self, "disable_module_getattr", False): + return attr_val + + def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache): + for n, p in collection_to_search: + if attr_val is p: + if n not in parameter_proxy_cache: + kwargs = {} + if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters: + kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else + lambda node: ColoProxy(self, node, n, attr_val)) + val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type] + parameter_proxy_cache[n] = val_proxy + return parameter_proxy_cache[n] + return None + + if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): + maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache) + if maybe_buffer_proxy is not None: + return maybe_buffer_proxy + + if isinstance(attr_val, torch.nn.Parameter): + maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), + parameter_proxy_cache) + if maybe_parameter_proxy is not None: + return maybe_parameter_proxy + + return attr_val diff --git a/colossalai/amp/__init__.py b/colossalai/amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..963215476b6b038b2aa33c124461387e47579d3c --- /dev/null +++ b/colossalai/amp/__init__.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.nn as nn +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer + +from colossalai.context import Config + +from .amp_type import AMP_TYPE +from .apex_amp import convert_to_apex_amp +from .naive_amp import convert_to_naive_amp +from .torch_amp import convert_to_torch_amp + +__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE'] + + +def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None): + """A helper function to wrap training components with Torch AMP modules. + + Args: + param model (:class:`torch.nn.Module`): your model object. + optimizer (:class:`torch.optim.Optimizer`): your optimizer object. + criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object. + mode (:class:`colossalai.amp.AMP_TYPE`): amp mode. + amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes. + + Returns: + A tuple (model, optimizer, criterion). + + Note: + ``amp_config`` may vary from different mode you choose. You should check the corresponding amp mode + for more details about ``amp_config``. + For ``apex_amp``, please check + `apex_amp config `_. + For ``naive_amp``, please check + `naive_amp config `_. + For ``torch_amp``, please check + `torch_amp config `_. + """ + assert isinstance(mode, AMP_TYPE), \ + f'expected the argument mode be AMP_TYPE, but got {type(mode)}' + + if amp_config is None: + amp_config = Config() + + if mode == AMP_TYPE.TORCH: + model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config) + elif mode == AMP_TYPE.APEX: + model, optimizer = convert_to_apex_amp(model, optimizer, amp_config) + elif mode == AMP_TYPE.NAIVE: + model, optimizer = convert_to_naive_amp(model, optimizer, amp_config) + + return model, optimizer, criterion diff --git a/colossalai/amp/amp_type.py b/colossalai/amp/amp_type.py new file mode 100644 index 0000000000000000000000000000000000000000..6f322f866cfc813e66e54b0c1006d62ef949e96e --- /dev/null +++ b/colossalai/amp/amp_type.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from enum import Enum + + +class AMP_TYPE(Enum): + APEX = 'apex' + TORCH = 'torch' + NAIVE = 'naive' diff --git a/colossalai/amp/apex_amp/__init__.py b/colossalai/amp/apex_amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..51b9b97dccce877783251fb3f61f08a87a6a7659 --- /dev/null +++ b/colossalai/amp/apex_amp/__init__.py @@ -0,0 +1,42 @@ +import torch.nn as nn +from torch.optim import Optimizer + +from .apex_amp import ApexAMPOptimizer + + +def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config): + r"""A helper function to wrap training components with Apex AMP modules + + Args: + model (:class:`torch.nn.Module`): your model object. + optimizer (:class:`torch.optim.Optimizer`): your optimizer object. + amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for initializing apex_amp. + + Returns: + Tuple: A tuple (model, optimizer). + + The ``amp_config`` should include parameters below: + :: + + enabled (bool, optional, default=True) + opt_level (str, optional, default="O1") + cast_model_type (``torch.dtype``, optional, default=None) + patch_torch_functions (bool, optional, default=None) + keep_batchnorm_fp32 (bool or str, optional, default=None + master_weights (bool, optional, default=None) + loss_scale (float or str, optional, default=None) + cast_model_outputs (torch.dtype, optional, default=None) + num_losses (int, optional, default=1) + verbosity (int, default=1) + min_loss_scale (float, default=None) + max_loss_scale (float, default=2.**24) + + More details about ``amp_config`` refer to `amp_config `_. + """ + import apex.amp as apex_amp + model, optimizer = apex_amp.initialize(model, optimizer, **amp_config) + optimizer = ApexAMPOptimizer(optimizer) + return model, optimizer + + +__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer'] diff --git a/colossalai/amp/apex_amp/apex_amp.py b/colossalai/amp/apex_amp/apex_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..e6bdbe4520f92450e80e930c0a7c746881e10bba --- /dev/null +++ b/colossalai/amp/apex_amp/apex_amp.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.nn as nn + +try: + import apex.amp as apex_amp +except ImportError: + pass + +from torch import Tensor + +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.utils import clip_grad_norm_fp32 + + +class ApexAMPOptimizer(ColossalaiOptimizer): + """ A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm + methods + """ + + def backward(self, loss: Tensor): + """Backward pass to get all gradients + + Args: + loss (torch.Tensor): Loss computed by a loss function + """ + with apex_amp.scale_loss(loss, self.optim) as scaled_loss: + scaled_loss.backward() + + def clip_grad_norm(self, model: nn.Module, max_norm: float): + """Clip gradients by norm + + Args: + model (torch.nn.Module): Your model object + max_norm (float): The max norm value for gradient clipping + """ + if max_norm > 0: + clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm) diff --git a/colossalai/amp/naive_amp/__init__.py b/colossalai/amp/naive_amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b2f71d3ced771c43d541843153c6b64613f69e1 --- /dev/null +++ b/colossalai/amp/naive_amp/__init__.py @@ -0,0 +1,60 @@ +import inspect + +import torch.nn as nn +from torch.optim import Optimizer + +from colossalai.utils import is_no_pp_or_last_stage + +from ._fp16_optimizer import FP16Optimizer +from .grad_scaler import ConstantGradScaler, DynamicGradScaler +from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer + + +def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config): + """A helper function to wrap training components with naive AMP modules. In this mode, + we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss, + which is equivalent to Apex O3. + + Args: + model (:class:`torch.nn.Module`): your model object + optimizer (:class:`torch.optim.Optimizer`): your optimizer object + amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp. + + Returns: + Tuple: A tuple (model, optimizer) + + The ``amp_config`` should contain parameters below:: + + verbose (bool, optional): if set to `True`, will print debug info (Default: False). + clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0). + Note that clipping is ignored if clip_grad == 0. + dynamic_grad_scale (bool): whether to use dynamic grad scaler. + """ + if isinstance(model, nn.ModuleList): + # interleaved pipeline + module_list = [] + for chunk, m in enumerate(model): + output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1 + module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32)) + model = nn.ModuleList(module_list) + else: + output_to_fp32 = is_no_pp_or_last_stage() + model = NaiveAMPModel(model, output_to_fp32=output_to_fp32) + + use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True) + if use_dynamic_grad_scaler: + scaler_class = DynamicGradScaler + else: + scaler_class = ConstantGradScaler + + sig = inspect.signature(scaler_class.__init__) + kwargs = dict() + for param in sig.parameters.values(): + if param.name in amp_config: + kwargs[param.name] = amp_config.pop(param.name) + grad_scaler = scaler_class(**kwargs) + optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config) + return model, optimizer + + +__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer'] diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/amp/naive_amp/_fp16_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e4699f92b9444005086a2a625ed243f0fa49ec44 --- /dev/null +++ b/colossalai/amp/naive_amp/_fp16_optimizer.py @@ -0,0 +1,371 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.optim import Optimizer + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.kernel.op_builder import FusedOptimBuilder +from colossalai.logging import get_dist_logger +from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier + +from ._utils import has_inf_or_nan, zero_gard_by_list +from .grad_scaler import BaseGradScaler + +try: + from colossalai._C import fused_optim +except: + fused_optim = None + +__all__ = ['FP16Optimizer'] + + +def load_fused_optim(): + global fused_optim + + if fused_optim is None: + fused_optim = FusedOptimBuilder().load() + + +def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): + """ + adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM) + + Use multi-tensor-applier to copy values from one list to another. + We don't have a blfoat16 implementation so for now if the overflow_buf + is not provided, we default back to simple loop copy to be compatible + with bfloat16. + """ + if overflow_buf: + overflow_buf.fill_(0) + # Scaling with factor `1.0` is equivalent to copy. + global fused_optim + load_fused_optim() + multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0) + else: + for this_, that_ in zip(this, that): + that_.copy_(this_) + + +class FP16Optimizer(Optimizer): + """Float16 optimizer for fp16 and bf16 data types. + + Args: + optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD + grad_scaler (BaseGradScaler): grad scaler for gradient chose in + ``constant_grad_scaler`` or ``dynamic_grad_scaler``. + clip_grad_norm (float, optional): clip gradients with this global L2 norm. Default 0. + Note that clipping is ignored if clip_grad == 0 + verbose (bool, optional): if set to `True`, will print debug info. Default False. + """ + + def __init__(self, + optimizer: Optimizer, + grad_scaler: BaseGradScaler, + verbose: bool = False, + clip_grad_norm=0, + dp_process_group: ProcessGroup = None, + mp_process_group: ProcessGroup = None): + # have a defaults for compatibility with pytorch optim + self._optimizer = optimizer + self._defaults = optimizer.defaults + + # fp16-related params + assert isinstance(grad_scaler, BaseGradScaler) + self._grad_scaler = grad_scaler + self._found_overflow = torch.cuda.FloatTensor([0.0]) + self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + + # misc params + self._clip_grad_max_norm = clip_grad_norm + + # get process group + def _get_process_group(parallel_mode): + if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode): + return gpc.get_group(parallel_mode) + else: + return None + + if dp_process_group is None: + dp_process_group = _get_process_group(ParallelMode.DATA) + if mp_process_group is None: + mp_process_group = _get_process_group(ParallelMode.MODEL) + + self._dp_process_group = dp_process_group + self._mp_process_group = mp_process_group + + # we maintain three groups of parameters + # so that the model can have a mixture + # of fp16 and fp32 params + # fp16_param_groups: the fp16 params of the model + # fp32_master_param_groups: the fp32 params cast from the fp16 param of the model + # fp32_param_groups: the fp32 params of the model + # NOTE: + # 1. fp16_param_groups and fp32_master_param_groups have one-to-one correspondence + # 2. fp32_param_groups and fp16_param_groups are exclusive of each other + self._fp16_param_groups = [] + self._fp32_master_param_groups = [] + self._fp32_param_groups = [] + + # For all the groups in the original optimizer: + for param_group in self._optimizer.param_groups: + fp16_params = [] + fp32_master_params = [] + fp32_params = [] + # For all the parameters in this group: + for i, param in enumerate(param_group['params']): + if param.requires_grad: + # float16 params: + if param.type() in ['torch.cuda.HalfTensor']: + fp16_params.append(param) + + # Create a fp32 copy + fp32_param = param.detach().clone().float() + # Copy tensor model parallel attributes. + copy_tensor_parallel_attributes(param, fp32_param) + + # Replace the optimizer params with the new fp32 copy. + param_group['params'][i] = fp32_param + fp32_master_params.append(fp32_param) + + # Reset existing state dict key to the new main param. + if param in self._optimizer.state: + self._optimizer.state[fp32_param] = self._optimizer.state.pop(param) + + # fp32 params. + elif param.type() == 'torch.cuda.FloatTensor': + fp32_params.append(param) + else: + raise TypeError('Expected parameter of type torch.cuda.FloatTensor ' + f'or torch.cuda.HalfTensor, but got {param.type()}') + + self._fp16_param_groups.append(fp16_params) + self._fp32_master_param_groups.append(fp32_master_params) + self._fp32_param_groups.append(fp32_params) + + # Leverage state_dict() and load_state_dict() to + # recast preexisting per-param state tensors + self._optimizer.load_state_dict(self._optimizer.state_dict()) + + # log config + self._logger = get_dist_logger() + if verbose: + self._logger.info( + f"\n========= FP16 Optimizer Config =========\n" + f"Optimizer: {optimizer.__class__.__name__}\n" + f"clip_grad_norm = {clip_grad_norm}\n" + f"grad_scaler = {self._grad_scaler.__class__.__name__}" + f"==========================================", + ranks=[0]) + + @property + def max_norm(self): + """Returns the maximum norm of gradient clipping. + """ + return self._clip_grad_max_norm + + @property + def grad_scaler(self): + """Returns the gradient scaler. + + Returns: + :class:`BaseGradScaler`: gradient scaler. + """ + + return self._grad_scaler + + @property + def loss_scale(self): + """Returns the loss scale. + + Returns: + int: loss scale. + """ + return self._grad_scaler.scale + + @property + def optimizer(self): + """Returns the optimizer. + + Returns: + :class:`torch.optim.Optimizer`: the optimizer object wrapped. + """ + return self._optimizer + + @property + def defaults(self): + """Returns the default arguments of optimizer. + + Returns: + dict: optimizer arguments saved in defaults of the optimizer wrapped. + """ + return self._defaults + + def _check_overflow(self): + # clear previous overflow record + self._found_overflow.fill_(0.0) + + # check for overflow + for group in self._optimizer.param_groups: + for p in group['params']: + if p.grad is not None and has_inf_or_nan(p.grad): + self._found_overflow.fill_(1.0) + break + + # all-reduce across dp group + if self._dp_process_group: + dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_process_group) + + # all-reduce over model parallel group + if self._mp_process_group: + dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_process_group) + + return self._found_overflow.item() > 0 + + def zero_grad(self, set_to_none=True): + """Set gradient to zero. + + Args: + set_to_none (bool): Whether set the gradient to None. + """ + + # set_to_none = True can save some memory space + for param_group in self._optimizer.param_groups: + zero_gard_by_list(param_group['params'], set_to_none=set_to_none) + + def _get_fp32_param_groups_to_update(self): + return self._fp32_master_param_groups + self._fp32_param_groups + + def _unscale_grads(self): + for group in self._get_fp32_param_groups_to_update(): + for p in group: + if p.grad is not None: + p.grad.data.div_(self.loss_scale) + + def _assign_grad_to_fp32_master_param(self): + # This only needs to be done for the float16 group. + for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups, self._fp32_master_param_groups): + for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group): + if fp16_param.grad is not None: + fp32_param.grad = fp16_param.grad.float() + # clear unneeded grad on fp16 param + fp16_param.grad = None + + def _update_fp16_param_from_fp32_param(self): + fp16_param_data = [] + fp32_master_param_data = [] + for fp16_group, fp32_group in zip(self._fp16_param_groups, self._fp32_master_param_groups): + for fp16_param, fp32_param in zip(fp16_group, fp32_group): + fp16_param_data.append(fp16_param.data) + fp32_master_param_data.append(fp32_param.data) + _multi_tensor_copy_this_to_that(this=fp32_master_param_data, + that=fp16_param_data, + overflow_buf=self._dummy_overflow_buf) + + def step(self): + """Update the model parameters. + """ + + # Copy gradients from model params to main params. + self._assign_grad_to_fp32_master_param() + self._unscale_grads() + + overflow = self._check_overflow() + self._grad_scaler.update(overflow) + if overflow: + self.zero_grad() + + # Clip the main gradients. + grad_norm = None + if self._clip_grad_max_norm > 0.0: + grad_norm = self.clip_grad_norm(self._clip_grad_max_norm) + + if not overflow: + # Step the optimizer. + self._optimizer.step() + + # Update params from main params. + self._update_fp16_param_from_fp32_param() + + # Successful update. + return True, grad_norm + else: + return False, None + + def backward(self, loss): + """Execute backward pass. + + Args: + loss (:class:`torch.Tensor`): the loss value. + """ + + scaled_loss = loss * self.grad_scaler.scale + scaled_loss.backward() + + def state_dict(self): + """Returns the states of the fp16 optimizer as a dict object. + """ + + state_dict = {} + state_dict['optimizer'] = self._optimizer.state_dict() + if self.grad_scaler: + state_dict['grad_scaler'] = self.grad_scaler.state_dict() + state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups + return state_dict + + def load_state_dict(self, state_dict): + """Load the states of the fp16 optimizer from a dict object. + + Args: + state_dict (dict): the states of the fp16 optimizer + """ + + # Optimizer. + self._optimizer.load_state_dict(state_dict['optimizer']) + + # Grad scaler. + if 'grad_scaler' in state_dict: + self.grad_scaler.load_state_dict(state_dict['grad_scaler']) + + # Copy data for the main params. + if 'fp32_master_param_groups' in state_dict: + for current_group, ckpt_group in zip(self._fp32_master_param_groups, + state_dict['fp32_master_param_groups']): + for current_param, ckpt_param in zip(current_group, ckpt_group): + current_param.data.copy_(ckpt_param.data) + + def clip_grad_norm(self, clip_grad): + """Clip gradients by norm. + + Args: + clip_grad (float): the max norm for clipping + """ + params = [] + for param_group in self._optimizer.param_groups: + for param in param_group['params']: + params.append(param) + return clip_grad_norm_fp32(params, clip_grad) + + # Promote state so it can be retrieved or set via + # "optimizer_instance.state" + def _get_state(self): + return self._optimizer.state + + def _set_state(self, value): + self._optimizer.state = value + + state = property(_get_state, _set_state) + + # Promote param_groups so it can be retrieved or set via + # "optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self._optimizer.param_groups + + def _set_param_groups(self, value): + self._optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) diff --git a/colossalai/amp/naive_amp/_utils.py b/colossalai/amp/naive_amp/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7633705e19fbce24faec87f9691c834279f0d8ad --- /dev/null +++ b/colossalai/amp/naive_amp/_utils.py @@ -0,0 +1,49 @@ +from typing import List + +from torch import Tensor + + +def has_inf_or_nan(tensor): + """Check if tensor has inf or nan values. + + Args: + tensor (:class:`torch.Tensor`): a torch tensor object + + Returns: + bool: Whether the tensor has inf or nan. True for yes and False for no. + """ + try: + # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as tensor + # (which is true for some recent version of pytorch). + tensor_sum = float(tensor.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # tensor_sum = float(tensor.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum: + return True + return False + + +def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None: + """Clear the gradient of a list of tensors, + + Note: copied from torch.optim.optimizer. + """ + for param in tensor_list: + if param.grad is not None: + if set_to_none: + param.grad = None + else: + if param.grad.grad_fn is not None: + param.grad.detach_() + else: + param.grad.requires_grad_(False) + param.grad.zero_() diff --git a/colossalai/amp/naive_amp/grad_scaler/__init__.py b/colossalai/amp/naive_amp/grad_scaler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8499d877e13f8e0eb14317d2cf4a8d54dfcb2a --- /dev/null +++ b/colossalai/amp/naive_amp/grad_scaler/__init__.py @@ -0,0 +1,5 @@ +from .base_grad_scaler import BaseGradScaler +from .constant_grad_scaler import ConstantGradScaler +from .dynamic_grad_scaler import DynamicGradScaler + +__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler'] diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..0d84384a7f67c6a4521a86d34f71ff03b821c7be --- /dev/null +++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod +from typing import Dict + +import torch +from torch import Tensor + +from colossalai.logging import get_dist_logger + +__all__ = ['BaseGradScaler'] + + +class BaseGradScaler(ABC): + """A base class for the gradient scaler. + + Args: + initial_scale (float): the initial loss scale + verbose (bool): whether to log messages + """ + + def __init__(self, initial_scale: float, verbose: bool): + assert initial_scale > 0 + self._scale = torch.cuda.FloatTensor([initial_scale]) + self._verbose = verbose + + if self._verbose: + self._logger = get_dist_logger() + + @property + def scale(self) -> Tensor: + """Returns the loss scale. + """ + + return self._scale + + @property + def inv_scale(self) -> Tensor: + """Returns the inverse of the loss scale. + """ + + return self._scale.double().reciprocal().float() + + def state_dict(self) -> Dict: + """Returns the states of the gradient scaler as a dict object. + """ + + state_dict = dict() + state_dict['scale'] = self.scale + return state_dict + + def load_state_dict(self, state_dict: Dict) -> None: + """Load the states of the gradient scaler from a dict object. + + Args: + state_dict (dict): the states of the gradient scaler + """ + + self._scale = state_dict['scale'] + + @abstractmethod + def update(self, overflow: bool) -> None: + """Update the loss scale. + + Args: + overflow (bool): whether overflow occurs + """ + + pass + + def log(self, message, *args, **kwargs): + """Log messages. + + Args: + message (str): the message to log + *args: positional arguments for :class:`colossalai.logging.DistributedLogger` + **kwargs: key-word arguments for :class:`colossalai.logging.DistributedLogger` + """ + + if self._verbose: + self._logger.info(message, *args, **kwargs) diff --git a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f518c5dd28261f98faadf9134721ef1fd67dc7 --- /dev/null +++ b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +from .base_grad_scaler import BaseGradScaler + +__all__ = ['ConstantGradScaler'] + + +class ConstantGradScaler(BaseGradScaler): + """A gradient scaler which uses constant loss scale + + Args: + initial_scale (float): the initial loss scale + verbose (bool): whether to log messages + """ + + def __init__(self, initial_scale: int, verbose: bool): + super().__init__(initial_scale, verbose) + self.log(f"Constant Gradient Scaler is initialized with scale {self.scale}", ranks=[0]) + + def update(self, overflow: bool) -> None: + """Do nothing to keep the loss scale constant. + + Args: + overflow (bool): whether overflow occurs + """ + pass diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..e899b9ca4c89fba16352ce736cb0abc4959e163b --- /dev/null +++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Optional + +import torch + +from .base_grad_scaler import BaseGradScaler + +__all__ = ['DynamicGradScaler'] + + +class DynamicGradScaler(BaseGradScaler): + """A gradient scaler which uses dynamic loss scale + + Args: + initial_scale (float): the initial loss scale, defaults to 2**16 + growth_factor (float): the multiplication factor for increasing loss scale, defaults to 2 + backoff_factor (float): the multiplication factor for decreasing loss scale, defaults to 0.5 + growth_interval (int): the number of steps to increase loss scale when no overflow occurs, defaults to 1000 + min_scale (float): the minimum loss scale, defaults to None + max_scale (float): the maximum loss scale, defaults to None + hysteresis (int): the number of overflows before decreasing loss scale, defaults to 2 + verbose (bool): whether to log messages, defaults to False + """ + + def __init__(self, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + min_scale: Optional[float] = None, + max_scale: Optional[float] = None, + hysteresis: int = 2, + verbose: bool = False): + super().__init__(initial_scale, verbose) + if min_scale: + self._min_scale = torch.cuda.FloatTensor([min_scale]) + else: + self._min_scale = None + + if max_scale: + self._max_scale = torch.cuda.FloatTensor([max_scale]) + else: + self._max_scale = None + + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._growth_step = 0 + self._hysteresis = hysteresis + self._hysteresis_step = 0 + self._sanity_checks() + + def _sanity_checks(self) -> None: + """Check if the arguments are correct. + """ + + if self._min_scale: + assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative' + assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale' + if self._max_scale: + assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative' + assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale' + assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1' + assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1' + assert self._hysteresis >= 0, 'The hysteresis cannot be negative' + + def update(self, overflow: bool) -> None: + """Update the loss scale. + + Args: + overflow (bool): whether overflow occurs + """ + if overflow: + self._hysteresis_step += 1 + self._growth_step = 0 + + if self._hysteresis_step >= self._hysteresis: + self._backoff_scale() + self.log(f"Overflow occurs, the loss scale is adjusted to {self.scale.item()}", ranks=[0]) + else: + self._growth_step += 1 + if self._growth_step == self._growth_interval: + self._growth_step = 0 + self._hysteresis_step = 0 + self._grow_scale() + self.log( + f"No overflow for consecutive {self._growth_interval} steps, " + f"the loss scale is adjusted to {self.scale.item()}", + ranks=[0]) + + def _backoff_scale(self) -> None: + """Decrease the loss scale + """ + + self._scale = self._scale * self._backoff_factor + if self._min_scale: + self._scale = torch.max(self._scale, self._min_scale) + + def _grow_scale(self) -> None: + """Increase the loss scale + """ + + self._scale = self._scale * self._growth_factor + if self._max_scale: + self._scale = torch.min(self._scale, self._max_scale) + + def state_dict(self): + state_dict = dict() + state_dict['scale'] = self._scale + state_dict['growth_factor'] = self._growth_factor + state_dict['backoff_factor'] = self._backoff_factor + state_dict['hysteresis'] = self._hysteresis + return state_dict + + def load_state_dict(self, state_dict): + self._scale = state_dict['scale'].cuda(torch.cuda.current_device()) + self._growth_factor = state_dict['growth_factor'] + self._backoff_factor = state_dict['backoff_factor'] + self._hysteresis = state_dict['hysteresis'] diff --git a/colossalai/amp/naive_amp/naive_amp.py b/colossalai/amp/naive_amp/naive_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..6a39d518d3f42716b800b7673fd128a4d6afe91b --- /dev/null +++ b/colossalai/amp/naive_amp/naive_amp.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Any + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch import Tensor +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.distributed import ReduceOp +from torch.optim import Optimizer + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.optimizer import ColossalaiOptimizer + +from ._fp16_optimizer import FP16Optimizer + + +class NaiveAMPOptimizer(ColossalaiOptimizer): + """A wrapper class for optimizer to cast all parameters to fp16 + + Args: + optim (torch.optim.Optimizer): A normal optimizer like Adam or SGD. + grad_scaler (BaseGradScaler): grad scaler for gradient chose in + ``constant_grad_scaler`` or ``dynamic_grad_scaler``. + clip_grad_norm (float, optional): clip gradients with this global L2 norm. Default 0. + verbose (bool, optional): if set to `True`, will print debug info. Default False. + + Note: + clipping is ignored if ``clip_grad_norm`` equals 0. + """ + + def __init__(self, optim: Optimizer, *args, **kwargs): + optim = FP16Optimizer(optim, *args, **kwargs) + super().__init__(optim) + + def backward(self, loss: Tensor): + self.optim.backward(loss) + + def step(self): + return self.optim.step() + + def clip_grad_norm(self, model: nn.Module, max_norm: float): + if self.optim.max_norm == max_norm: + return + raise RuntimeError("NaiveAMP optimizer has clipped gradients during optimizer.step(). " + "If you have supplied clip_grad_norm in the amp_config, " + "executing the method clip_grad_norm is not allowed.") + + +class NaiveAMPModel(nn.Module): + r"""A wrapper class for model to cast the model into fp16 and + automatically cast the input and output + + Args: + model (torch.nn.Module): torch.nn.Module to be wrapped. + output_to_fp32 (bool, optional): Whether cast output of this module into fp32. (Default: True) + parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this module. + (Default: ``ParallelMode.DATA``) + sync_buffer (bool, optional): whether to synchronize buffer. (Default: True) + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + def __init__(self, + model: nn.Module, + output_to_fp32: bool = True, + parallel_mode: ParallelMode = ParallelMode.DATA, + sync_buffer: bool = True): + super().__init__() + self.model = model.half() + self._output_to_fp32 = output_to_fp32 + self._sync_buf = sync_buffer + + if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: + self._process_group = gpc.get_group(parallel_mode) + self._world_size = gpc.get_world_size(parallel_mode) + else: + self._process_group = None + self._world_size = 1 + self._sync_buf = False + self._first_eval_run = False + + @property + def sync_buffer(self): + return self._sync_buf + + @sync_buffer.setter + def sync_buffer(self, state: bool): + self._sync_buf = state + + def _convert_to_fp16(self, input_: Any): + if isinstance(input_, Tensor) and input_.dtype == torch.float32: + input_ = input_.half() + return input_ + + def _convert_to_fp32(self, input_: Any): + if isinstance(input_, Tensor) and input_.dtype == torch.float16: + input_ = input_.float() + return input_ + + def _reduce_module_buffer(self): + """ + All-reduce the buffers (e.g. running stats of batch normalization) across + data parallel ranks so that all the ranks will produce consistent results + when given the same input + """ + buf_list = [] + + # find valid buffers + for buf in self.model.buffers(): + if buf is not None: + buf_list.append(buf) + + # reduce buffers across data parallel ranks + if buf_list: + coalesced_buf = _flatten_dense_tensors(buf_list) + coalesced_buf.div_(self._world_size) + dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group) + unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list) + for old, new in zip(buf_list, unflattened_buf_list): + old.copy_(new) + + def eval(self): + self.model.eval() + + # we only sync buffer in the first eval iteration + # so that future eval iterations can be done without communication + self._first_eval_run = True + + def forward(self, *args, **kwargs): + # reduce buffers after forward will lead to error + # as we cannot change the variables needed for gradient computation after forward + # so we sync buffer before forward + if (self.training or self._first_eval_run) and self._sync_buf: + with torch.no_grad(): + self._reduce_module_buffer() + + if self._first_eval_run: + self._first_eval_run = False + + if args: + args = [self._convert_to_fp16(arg) for arg in args] + if kwargs: + for k, v in kwargs.items(): + kwargs[k] = self._convert_to_fp16(v) + + out = self.model(*args, **kwargs) + + if self._output_to_fp32: + if isinstance(out, Tensor): + out = self._convert_to_fp32(out) + elif isinstance(out, (tuple, list)): + out = [self._convert_to_fp32(val) for val in out] + elif isinstance(out, dict): + out = {key: self._convert_to_fp32(val) for key, val in out.items()} + return out diff --git a/colossalai/amp/torch_amp/__init__.py b/colossalai/amp/torch_amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..893cc890d68e423c643e6dc4bbf6343ff174a8d7 --- /dev/null +++ b/colossalai/amp/torch_amp/__init__.py @@ -0,0 +1,45 @@ +from typing import Optional + +import torch.nn as nn +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer + +from colossalai.context import Config + +from .torch_amp import TorchAMPLoss, TorchAMPModel, TorchAMPOptimizer + + +def convert_to_torch_amp(model: nn.Module, + optimizer: Optimizer, + criterion: Optional[_Loss] = None, + amp_config: Optional[Config] = None): + """A helper function to wrap training components with Pytorch AMP modules + + Args: + model (:class:`torch.nn.Module`): your model object. + optimizer (:class:`torch.optim.Optimizer`): your optimizer object + criterion (:class:`torch.nn.modules.loss._Loss`, optional): your loss function object + amp_config (:class:`colossalai.context.Config` or dict, optional): configuration for Pytorch AMP. + + The ``amp_config`` should include parameters below: + :: + + init_scale (float, optional, default=2.**16) + growth_factor (float, optional, default=2.0) + backoff_factor (float, optional, default=0.5) + growth_interval (int, optional, default=2000) + enabled (bool, optional, default=True) + + Returns: + A tuple (model, optimizer, criterion) + """ + model = TorchAMPModel(model) + if amp_config is None: + amp_config = dict() + optimizer = TorchAMPOptimizer(optimizer, **amp_config) + if criterion: + criterion = TorchAMPLoss(criterion) + return model, optimizer, criterion + + +__all__ = ['convert_to_torch_amp', 'TorchAMPModel', 'TorchAMPLoss', 'TorchAMPOptimizer'] diff --git a/colossalai/amp/torch_amp/_grad_scaler.py b/colossalai/amp/torch_amp/_grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..7b78998fb8c233f13f34fdf64df95bdfd1601ee6 --- /dev/null +++ b/colossalai/amp/torch_amp/_grad_scaler.py @@ -0,0 +1,571 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py +# to support tensor parallel + +import warnings +from collections import abc, defaultdict +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist +from packaging import version +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + + +class _MultiDeviceReplicator(object): + """ + Lazily serves copies of a tensor to requested devices. Copies are cached per-device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + assert master_tensor.is_cuda or master_tensor.device.type == 'xla' + self.master = master_tensor + self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} + + def get(self, device) -> torch.Tensor: + retval = self._per_device_tensors.get(device, None) + if retval is None: + retval = self.master.to(device=device, non_blocking=True, copy=True) + self._per_device_tensors[device] = retval + return retval + + +# Defines default_factory for GradScaler's _per_optimizer_states defaultdict, +# as well as associated "enum" values. Prefers defining these at top level because +# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. +# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler +# causes a circular reference, which we'd rather avoid. +class OptState(Enum): + READY = 0 + UNSCALED = 1 + STEPPED = 2 + + +def _refresh_per_optimizer_state(): + return {"stage": OptState.READY, "found_inf_per_device": {}} + + +class GradScaler(object): + _scale: Optional[torch.Tensor] + _grows_tracker: Optional[torch.Tensor] + _per_optimizer_states: Dict[int, Dict[str, Any]] + """ + An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling + conveniently. + + * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. + * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. + * ``scaler.update()`` updates ``scaler``'s scale factor. + + Example: + + # Creates a GradScaler once at the beginning of training. + scaler = GradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See the :ref:`Automatic Mixed Precision examples` for usage + (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, + and multiple losses/optimizers. + + ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, + a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if + the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used + without incurring inf or NaN gradient values. + ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every + ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). + + * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params + themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. + + * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. + If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by + ``growth_factor``. + + The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its + value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these + iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). + + Args: + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + """ + + def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True): + if enabled and not torch.cuda.is_available(): + warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.") + self._enabled = False + else: + self._enabled = enabled + + # check version + torch_version = version.parse(torch.__version__) + assert torch_version.major == 1 + if torch_version.minor > 8: + self._higher_than_torch18 = True + else: + self._higher_than_torch18 = False + + if self._enabled: + assert growth_factor > 1.0, "The growth factor must be > 1.0." + assert backoff_factor < 1.0, "The backoff factor must be < 1.0." + + self._init_scale = init_scale + # self._scale will be lazily initialized during the first call to scale() + self._scale = None + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._init_growth_tracker = 0 + # self._growth_tracker will be lazily initialized during the first call to scale() + self._growth_tracker = None + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]: + fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." + assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix + assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix + return (self._scale, self._growth_tracker) + + def _lazy_init_scale_growth_tracker(self, dev): + assert self._growth_tracker is None, "_growth_tracker initialized before _scale" + self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=dev) + self._growth_tracker = torch.full((1,), self._init_growth_tracker, dtype=torch.int32, device=dev) + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Args: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + + # Short-circuit for the common case. + if isinstance(outputs, torch.Tensor): + assert outputs.is_cuda or outputs.device.type == 'xla' + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + return outputs * self._scale.to(device=outputs.device, non_blocking=True) + + # Invoke the more complex machinery only if we're treating multiple outputs. + # holds a reference that can be overwritten by apply_scale + stash: List[_MultiDeviceReplicator] = [] + + def apply_scale(val): + if isinstance(val, torch.Tensor): + assert val.is_cuda or val.device.type == 'xla' + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_MultiDeviceReplicator(self._scale)) + return val * stash[0].get(val.device) + elif isinstance(val, abc.Iterable): + iterable = map(apply_scale, val) + if isinstance(val, list) or isinstance(val, tuple): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): + per_device_inv_scale = _MultiDeviceReplicator(inv_scale) + per_device_found_inf = _MultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + # TODO: is there a way to split by device and dtype without appending in the inner loop? + per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale) + + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device), + per_device_inv_scale.get(device)) + # For tensor parallel paramters it should be all-reduced over tensor parallel process group + if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: + vals = [val for val in per_device_found_inf._per_device_tensors.values()] + coalesced = _flatten_dense_tensors(vals) + dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL)) + for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)): + buf.copy_(synced) + return per_device_found_inf._per_device_tensors + + def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + + .. note:: + :meth:`unscale_` does not incur a CPU-GPU sync. + + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: + raise RuntimeError("unscale_() has already been called on this optimizer since the last update().") + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = self._scale.double().reciprocal().float() + found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False) + optimizer_state["stage"] = OptState.UNSCALED + + def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): + retval = None + if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()): + retval = optimizer.step(*args, **kwargs) + return retval + + def step(self, optimizer, *args, **kwargs): + """ + :meth:`step` carries out the following two operations: + + 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` + earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. + 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled + gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. + + ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. + + Returns the return value of ``optimizer.step(*args, **kwargs)``. + + Args: + optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. + args: Any arguments. + kwargs: Any keyword arguments. + + .. warning:: + Closure use is not currently supported. + """ + if (not self._enabled): + return optimizer.step(*args, **kwargs) + + if "closure" in kwargs: + raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.") + + self._check_scale_growth_tracker("step") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("step() has already been called since the last update().") + + retval = None + + if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): + # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. + # The contract with custom optimizers is that their step() should accept an additional, + # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: + # it can query its own state, invoke unscale_ on itself, etc + retval = optimizer.step(*args, **dict(kwargs, grad_scaler=self)) + optimizer_state["stage"] = OptState.STEPPED + return retval + + if optimizer_state["stage"] is OptState.READY: + self.unscale_(optimizer) + + assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer." + + retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs) + + optimizer_state["stage"] = OptState.STEPPED + + return retval + + def update(self, new_scale=None): + """ + Updates the scale factor. + + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + + Args: + new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. + + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." + # type: ignore[attr-defined] + assert isinstance(new_scale, torch.cuda.FloatTensor), reason + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device=_scale.device, non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + if self._higher_than_torch18: + torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor, + self._backoff_factor, self._growth_interval) + else: + self._scale = torch._amp_update_scale(_growth_tracker, _scale, found_inf_combined, self._growth_factor, + self._backoff_factor, self._growth_interval) + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _get_scale_async(self): + return self._scale + + def get_scale(self): + """ + Returns a Python float containing the current scale, or 1.0 if scaling is disabled. + + .. warning:: + :meth:`get_scale` incurs a CPU-GPU sync. + """ + if self._enabled: + return self._init_scale if self._scale is None else self._get_scale_async().item() + else: + return 1.0 + + def get_growth_factor(self): + r""" + Returns a Python float containing the scale growth factor. + """ + return self._growth_factor + + def set_growth_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale growth factor. + """ + self._growth_factor = new_factor + + def get_backoff_factor(self): + r""" + Returns a Python float containing the scale backoff factor. + """ + return self._backoff_factor + + def set_backoff_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale backoff factor. + """ + self._backoff_factor = new_factor + + def get_growth_interval(self): + r""" + Returns a Python int containing the growth interval. + """ + return self._growth_interval + + def set_growth_interval(self, new_interval): + r""" + Args: + new_interval (int): Value to use as the new growth interval. + """ + self._growth_interval = new_interval + + def _get_growth_tracker(self): + if self._enabled: + return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item() + else: + return 0 + + def is_enabled(self): + r""" + Returns a bool indicating whether this instance is enabled. + """ + return self._enabled + + def state_dict(self): + r""" + Returns the state of the scaler as a :class:`dict`. It contains five entries: + + * ``"scale"`` - a Python float containing the current scale + * ``"growth_factor"`` - a Python float containing the current growth factor + * ``"backoff_factor"`` - a Python float containing the current backoff factor + * ``"growth_interval"`` - a Python int containing the current growth interval + * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. + + If this instance is not enabled, returns an empty dict. + + .. note:: + If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` + should be called after :meth:`update`. + """ + return { + "scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker() + } if self._enabled else {} + + def load_state_dict(self, state_dict): + r""" + Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. + + Args: + state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. + """ + if not self._enabled: + return + + if len(state_dict) == 0: + raise RuntimeError("The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler.") + + self._init_scale = state_dict["scale"] + if self._scale is not None: + self._scale.fill_(state_dict["scale"]) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._growth_interval = state_dict["growth_interval"] + self._init_growth_tracker = state_dict["_growth_tracker"] + if self._growth_tracker is not None: + self._growth_tracker.fill_(state_dict["_growth_tracker"]) + + def __getstate__(self): + state = self.__dict__.copy() + if self._enabled: + assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ + "of an iteration, or at the end after scaler.update()." + # Pickling _scale and _growth_tracker Tensors directly triggers + # "warnings.warn("pickle support for Storage will be removed in 1.5..." + # so instead, we set the unpickled instance up to reinitialize them lazily. + state['_init_scale'] = self.get_scale() + state['_init_growth_tracker'] = self._get_growth_tracker() + state['_scale'] = None + state['_growth_tracker'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + def _check_inf_per_device(self, optimizer): + _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") + + dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device) + found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device) + + self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ + self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] + + def _found_inf_per_device(self, optimizer): + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/colossalai/amp/torch_amp/torch_amp.py b/colossalai/amp/torch_amp/torch_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..65718d77c2e00cdaf83ca8c27e9c26caed0d9362 --- /dev/null +++ b/colossalai/amp/torch_amp/torch_amp.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.cuda.amp as torch_amp +import torch.nn as nn +from torch import Tensor +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer + +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.utils import clip_grad_norm_fp32 + +from ._grad_scaler import GradScaler + + +class TorchAMPOptimizer(ColossalaiOptimizer): + """A wrapper class which integrate Pytorch AMP with an optimizer + + Args: + optim (torch.optim.Optimizer): A normal optimizer like Adam or SGD. + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + """ + + def __init__(self, optim: Optimizer, *args, **kwargs): + super().__init__(optim) + self.scaler = GradScaler(*args, **kwargs) + + def backward(self, loss: Tensor): + """Backward with torch amp gradient scaler + + Args: + loss (torch.Tensor): Loss computed by a loss function + """ + self.scaler.scale(loss).backward() + + def step(self): + """Update the parameters of the model + """ + self.scaler.step(self.optim) + self.scaler.update() + + def clip_grad_norm(self, model: nn.Module, max_norm: float): + """Apply gradient clipping to the model parameters + + Args: + model (torch.nn.Module): Your model object + max_norm (float): Max norm value for gradient clipping + """ + if max_norm > 0.0: + self.scaler.unscale_(self.optim) + clip_grad_norm_fp32(model.parameters(), max_norm) + + +class TorchAMPModel(nn.Module): + """A wrapper class for a model object which executes forward with values automatically + cast to fp16 + + Args: + model (:class:`torch.nn.Module`): a torch model instance + """ + + def __init__(self, model: nn.Module) -> None: + super().__init__() + self.model = model + + @torch_amp.autocast() + def forward(self, *args, **kwargs): + """ + Execute forward under the torch amp context + """ + return self.model(*args, **kwargs) + + +class TorchAMPLoss(nn.Module): + """A wrapper class for a criterion object which computes the loss in mixed-precision context + + Args: + loss (torch.nn.modules.loss._Loss): A loss function object + """ + + def __init__(self, loss: _Loss): + super().__init__() + self.loss = loss + + @torch_amp.autocast() + def forward(self, *args, **kwargs): + """ + Execute forward under the torch amp context + """ + return self.loss(*args, **kwargs) diff --git a/colossalai/auto_parallel/README.md b/colossalai/auto_parallel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8e47e1bb0b4a6e8e86c1e76d600d3dae3c8be251 --- /dev/null +++ b/colossalai/auto_parallel/README.md @@ -0,0 +1,23 @@ +# Colossal-AUTO + +## Challenges +Recently, large models have achieved the state of the art performances in various fields. In order to support large model training, we have to use distributed training techniques. However, finding an efficient distributed execution plan not only requires fine-grained model statistics, such as memory and computing overhead of each operator but also is a labor-intensive task even for an expert in the field of distributed training. + +## Our solution +To simplify the process of distributed training for foundational models, recent advancements in machine learning systems have led to the emergence of automatic parallel systems. We investigate and research a number of current automatic parallel systems( Tofu , Flexflow , Alpa ) and some auto activation checkpoint algorithms( Rotor , Sublinear ). Inspired from these advanced systems, we build an automatic parallel system upon PyTorch framework. The input of the system is the serial PyTorch code, and the output is a PyTorch program with an optimized distributed execution plan. It is worth emphasizing that the output is a regular PyTorch program, so it is compatible with runtime optimization methods, such as ZeRO-Offload and PatrickStar. + +## Key modules + +### Analyzer + +**Analyzer** is a static analysis system consisting of three parts: +A *symbolic profiler* for collecting computing and memory overhead related to static computation graph, a *cluster detector* for collecting hardware characteristics and detecting cluster topology and a *tensor layout manager* to find efficient tensor layout conversion path from different sharding spec and record conversion cost. + +### Solver + +**Solver** is designed to find the optimal execution plan for a given computation graph and cluster in two stages: +1) *Intra-op parallelism stage* is to find the plan with the minimum total execution time of all nodes with respect to the constraint of the memory budget. The optimaztion goal of intra-op parallelism solver is modified from Alpa 's intra-op parallelsim ILP solver. +2) *Activation checkpoint stage* is to search for the fastest execution plan that meets the memory budget on the computation graph after inserting the communication nodes by the intra-op parallelism stage. The algorithm to find optimial activation checkpoint is modified from Rotor . The reason we use two-stage optimization is that if the two tasks are formulated together, the solving time will be significantly increased, which will greatly affect the user experience of the system. On the contrary, solving in two hierarchical levels has many advantages. Firstly, compared with the computation graph with activation checkpointing, the original graph has fewer nodes, which can reduce the solving cost of intra-op parallelism solver. In addition, a more optimal solution can be found by adding the communication overhead into the activation checkpoint modeling. + +### Generator +**Generator** applies the searched execution plan to the computation graph and recompiles the computation graph to optimized PyTorch code. It has *a series compile pass* to insert a communication node or do the kernel substitution as the intra-op parallelism solver required. Additionally, we implement a *code generation* feature to recognize the annotation from the activation checkpoint solver and inject the activation checkpoint block following annotation instructions. diff --git a/colossalai/auto_parallel/__init__.py b/colossalai/auto_parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colossalai/auto_parallel/checkpoint/__init__.py b/colossalai/auto_parallel/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10ade417a238753af49c3780cd6693b362b7bbb4 --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/__init__.py @@ -0,0 +1,3 @@ +from .ckpt_solver_base import CheckpointSolverBase +from .ckpt_solver_chen import CheckpointSolverChen +from .ckpt_solver_rotor import CheckpointSolverRotor diff --git a/colossalai/auto_parallel/checkpoint/build_c_ext.py b/colossalai/auto_parallel/checkpoint/build_c_ext.py new file mode 100644 index 0000000000000000000000000000000000000000..af4349865a7b8dd748e34458eb3d5aeeb359b599 --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/build_c_ext.py @@ -0,0 +1,16 @@ +import os + +from setuptools import Extension, setup + +this_dir = os.path.dirname(os.path.abspath(__file__)) +ext_modules = [Extension( + 'rotorc', + sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')], +)] + +setup( + name='rotor c extension', + version='0.1', + description='rotor c extension for faster dp computing', + ext_modules=ext_modules, +) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py new file mode 100644 index 0000000000000000000000000000000000000000..b388d00ac553726f577575d5d770b98dfb873f12 --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py @@ -0,0 +1,195 @@ +from abc import ABC, abstractmethod +from copy import deepcopy +from typing import Any, List + +import torch +from torch.fx import Graph, Node + +from colossalai.auto_parallel.passes.runtime_apply_pass import ( + runtime_apply, + runtime_apply_for_iterable_object, + runtime_comm_spec_apply, +) +from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen + +__all___ = ['CheckpointSolverBase'] + + +def _copy_output(src: Graph, dst: Graph): + """Copy the output node from src to dst""" + for n_src, n_dst in zip(src.nodes, dst.nodes): + if n_src.op == 'output': + n_dst.meta = n_src.meta + + +def _get_param_size(module: torch.nn.Module): + """Get the size of the parameters in the module""" + return sum([p.numel() * torch.tensor([], dtype=p.dtype).element_size() for p in module.parameters()]) + + +class CheckpointSolverBase(ABC): + + def __init__( + self, + graph: Graph, + free_memory: float = -1.0, + requires_linearize: bool = False, + cnode: List[str] = None, + optim_multiplier: float = 1.0, + ): + """``CheckpointSolverBase`` class will integrate information provided by the components + and use an existing solver to find a possible optimal strategies combination for target + computing graph. + + Existing Solvers: + Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen) + Rotor solver: https://hal.inria.fr/hal-02352969 (CheckpointSolverRotor) + + Args: + graph (Graph): The computing graph to be optimized. + free_memory (float): Memory constraint for the solution. + requires_linearize (bool): Whether the graph needs to be linearized. + cnode (List[str], optional): Common node List, should be the subset of input. Default to None. + optim_multiplier (float, optional): The multiplier of extra weight storage for the + ``torch.optim.Optimizer``. Default to 1.0. + + Warnings: + Meta information of the graph is required for any ``CheckpointSolver``. + """ + # super-dainiu: this graph is a temporary graph which can refer to + # the owning module, but we will return another deepcopy of it after + # the solver is executed. + self.graph = deepcopy(graph) + self.graph.owning_module = graph.owning_module + _copy_output(graph, self.graph) + self.graph.set_codegen(ActivationCheckpointCodeGen()) + + # check if has meta information + if any(len(node.meta) == 0 for node in self.graph.nodes): + raise RuntimeError( + "Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!" + ) + + # parameter memory = parameter size + optimizer extra weight storage + self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1) + self.cnode = cnode + self.requires_linearize = requires_linearize + if self.requires_linearize: + self.node_list = self._linearize_graph() + else: + self.node_list = self.get_node_list() + + @abstractmethod + def solve(self): + """Solve the checkpointing problem and return the solution. + """ + pass + + def get_node_list(self): + """Get the node list. + """ + return [[node] for node in self.graph.nodes] + + def _linearize_graph(self) -> List[List[Node]]: + """Linearizing the graph + + Args: + graph (Graph): The computing graph to be optimized. + + Returns: + List[List[Node]]: List of list, each inside list of Node presents + the actual 'node' in linearized manner. + + Remarks: + Do merge the inplace ops and shape-consistency ops into the previous node. + """ + + # Common nodes are type of nodes that could be seen as attributes and remain + # unchanged throughout the whole model, it will be used several times by + # different blocks of model, so that it is hard for us to linearize the graph + # when we encounter those kinds of nodes. We let users to annotate some of the + # input as common node, such as attention mask, and the followings are some of + # the ops that could actually be seen as common nodes. With our common node prop, + # we could find some of the "real" common nodes (e.g. the real attention mask + # used in BERT and GPT), the rule is simple, for node who's parents are all common + # nodes or it's op belongs to the following operations, we view this node as a + # newly born common node. + # List of target name that could be seen as common node + common_ops = ["getattr", "getitem", "size"] + + def _is_cop(target: Any) -> bool: + """Check if an op could be seen as common node + + Args: + target (Any): node target + + Returns: + bool + """ + + if isinstance(target, str): + return target in common_ops + else: + return target.__name__ in common_ops + + def _is_sink() -> bool: + """Check if we can free all dependencies + + Returns: + bool + """ + + def _is_inplace(n: Node): + """Get the inplace argument from ``torch.fx.Node`` + """ + inplace = False + if n.op == "call_function": + inplace = n.kwargs.get("inplace", False) + elif n.op == "call_module": + inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) + return inplace + + def _is_shape_consistency(n: Node): + """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``) + """ + return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply] + + return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any( + map(_is_shape_consistency, n.users)) + + # make sure that item in cnode is valid + if self.cnode: + for name in self.cnode: + try: + assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \ + f"Common node {name} is not an input of the model." + except StopIteration: + raise ValueError(f"Common node name {name} not in graph.") + + else: + self.cnode = [] + + deps = {} + node_list = [] + region = [] + + for n in self.graph.nodes: + if n.op != "placeholder" and n.op != "output": + for n_par in n.all_input_nodes: + if n_par.op != "placeholder" and n_par.name not in self.cnode: + deps[n_par] -= 1 + region.append(n) + + # if the node could free all dependencies in graph + # we could begin a new node + if _is_sink(): + node_list.append(region) + region = [] + + # propagate common node attr if possible + if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode + ]) or _is_cop(n.target): + self.cnode.append(n.name) + else: + deps[n] = len([user for user in n.users if user.op != "output"]) + return node_list diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py new file mode 100644 index 0000000000000000000000000000000000000000..19b2ef5987c9ebc160078339b764741a71b34dbf --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py @@ -0,0 +1,87 @@ +import math +from copy import deepcopy +from typing import List, Set, Tuple + +from torch.fx import Graph, Node + +from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp + +from .ckpt_solver_base import CheckpointSolverBase + +__all__ = ['CheckpointSolverChen'] + + +class CheckpointSolverChen(CheckpointSolverBase): + + def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6): + """ + This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. + Note that this algorithm targets at memory optimization only, using techniques in appendix A. + + Usage: + Assume that we have a ``GraphModule``, and we have already done the extractions + to the graph to retrieve all information needed, then we could use the following + code to find a solution using ``CheckpointSolverChen``: + >>> solver = CheckpointSolverChen(gm.graph) + >>> chen_graph = solver.solve() + >>> gm.graph = chen_graph # set the graph to a new graph + + Args: + graph (Graph): The computing graph to be optimized. + cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None. + num_grids (int, optional): Number of grids to search for b. Defaults to 6. + """ + super().__init__(graph, 0, 0, True, cnode) + self.num_grids = num_grids + + def solve(self) -> Graph: + """Solve the checkpointing problem using Algorithm 3. + + Returns: + graph (Graph): The optimized graph, should be a copy of the original graph. + """ + checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr'] + ckpt = self.grid_search() + for i, seg in enumerate(ckpt): + for idx in range(*seg): + nodes = self.node_list[idx] + for n in nodes: + if n.op in checkpointable_op: + n.meta['activation_checkpoint'] = i + return deepcopy(self.graph) + + def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]: + """ + This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. + """ + ckpt_intv = [] + temp = 0 + x = 0 + y = 0 + prev_idx = 2 + for idx, nodes in enumerate(self.node_list): + for n in nodes: + n: Node + temp += calculate_fwd_in(n) + calculate_fwd_tmp(n) + y = max(y, temp) + if temp > b and idx > prev_idx: + x += calculate_fwd_in(nodes[0]) + temp = 0 + ckpt_intv.append((prev_idx, idx + 1)) + prev_idx = idx + 1 + return ckpt_intv, math.floor(math.sqrt(x * y)) + + def grid_search(self) -> Set: + """ + Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy. + Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A. + """ + _, b_approx = self.run_chen_greedy(0) + b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2)) + b_opt = math.inf + for b in range(b_min, b_max, (b_max - b_min) // self.num_grids): + ckpt_intv, b_approx = self.run_chen_greedy(b) + if b_approx < b_opt: + b_opt = b_approx + ckpt_opt = ckpt_intv + return ckpt_opt diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c new file mode 100644 index 0000000000000000000000000000000000000000..8dad074bc894d31f4c6c20fabddcff06c57f23cd --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c @@ -0,0 +1,209 @@ +#define PY_SSIZE_T_CLEAN +#include + +/* +Rotor solver for checkpointing problem in C. We follow the modeling mentioned in +paper `Optimal checkpointing for heterogeneous chains: how to train deep neural +networks with limited memory` https://hal.inria.fr/hal-02352969. Some lines of +the code are adapted from https://gitlab.inria.fr/hiepacs/rotor. +*/ +long* PySequenceToLongArray(PyObject* pylist) { + if (!(pylist && PySequence_Check(pylist))) return NULL; + Py_ssize_t len = PySequence_Size(pylist); + long* result = (long*)calloc(len + 1, sizeof(long)); + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* item = PySequence_GetItem(pylist, i); + result[i] = PyLong_AsLong(item); + Py_DECREF(item); + } + result[len] = 0; + return result; +} + +double* PySequenceToDoubleArray(PyObject* pylist) { + if (!(pylist && PySequence_Check(pylist))) return NULL; + Py_ssize_t len = PySequence_Size(pylist); + double* result = (double*)calloc(len + 1, sizeof(double)); + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* item = PySequence_GetItem(pylist, i); + result[i] = PyFloat_AsDouble(item); + Py_DECREF(item); + } + result[len] = 0; + return result; +} + +long* getLongArray(PyObject* container, const char* attributeName) { + PyObject* sequence = PyObject_GetAttrString(container, attributeName); + long* result = PySequenceToLongArray(sequence); + Py_DECREF(sequence); + return result; +} + +double* getDoubleArray(PyObject* container, const char* attributeName) { + PyObject* sequence = PyObject_GetAttrString(container, attributeName); + double* result = PySequenceToDoubleArray(sequence); + Py_DECREF(sequence); + return result; +} + +static PyObject* computeTable(PyObject* self, PyObject* args) { + PyObject* chainParam; + int mmax; + + if (!PyArg_ParseTuple(args, "Oi", &chainParam, &mmax)) return NULL; + + double* ftime = getDoubleArray(chainParam, "ftime"); + if (!ftime) return NULL; + + double* btime = getDoubleArray(chainParam, "btime"); + if (!btime) return NULL; + + long* x = getLongArray(chainParam, "x"); + if (!x) return NULL; + + long* xbar = getLongArray(chainParam, "xbar"); + if (!xbar) return NULL; + + long* ftmp = getLongArray(chainParam, "btmp"); + if (!ftmp) return NULL; + + long* btmp = getLongArray(chainParam, "btmp"); + if (!btmp) return NULL; + + long chainLength = PyObject_Length(chainParam); + if (!chainLength) return NULL; + +#define COST_TABLE(m, i, l) \ + costTable[(m) * (chainLength + 1) * (chainLength + 1) + \ + (i) * (chainLength + 1) + (l)] + double* costTable = (double*)calloc( + (mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(double)); + +#define BACK_PTR(m, i, l) \ + backPtr[(m) * (chainLength + 1) * (chainLength + 1) + \ + (i) * (chainLength + 1) + (l)] + long* backPtr = (long*)calloc( + (mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long)); + + for (long m = 0; m <= mmax; ++m) + for (long i = 0; i <= chainLength; ++i) { + if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) && + (m >= x[i + 1] + xbar[i + 1] + ftmp[i])) { + COST_TABLE(m, i, i) = ftime[i] + btime[i]; + } else { + COST_TABLE(m, i, i) = INFINITY; + } + } + + for (long m = 0; m <= mmax; ++m) { + for (long d = 1; d <= chainLength; ++d) { + for (long i = 0; i <= chainLength - d; ++i) { + long idx = i + d; + long mmin = x[idx + 1] + x[i + 1] + ftmp[i]; + if (idx > i + 1) { + long maxCostFWD = 0; + for (long j = i + 1; j < idx; j++) { + maxCostFWD = fmaxl(maxCostFWD, x[j] + x[j + 1] + ftmp[j]); + } + mmin = fmaxl(mmin, x[idx + 1] + maxCostFWD); + } + if ((m >= mmin)) { + long bestLeaf = -1; + double sumFw = 0; + double bestLeafCost = INFINITY; + for (long j = i + 1; j <= idx; ++j) { + sumFw += ftime[j - 1]; + if (m >= x[j]) { + double cost = sumFw + COST_TABLE(m - x[j], j, idx) + + COST_TABLE(m, i, j - 1); + if (cost < bestLeafCost) { + bestLeafCost = cost; + bestLeaf = j; + } + } + } + double chainCost = INFINITY; + if (m >= xbar[i + 1]) { + chainCost = + COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx); + } + if (bestLeafCost <= chainCost) { + COST_TABLE(m, i, idx) = bestLeafCost; + BACK_PTR(m, i, idx) = bestLeaf; + } else { + COST_TABLE(m, i, idx) = chainCost; + BACK_PTR(m, i, idx) = -1; + } + } else { + COST_TABLE(m, i, idx) = INFINITY; + } + } + } + } + + free(ftime); + free(btime); + free(x); + free(xbar); + free(ftmp); + free(btmp); + + PyObject* pyCostTable = PyList_New(mmax + 1); + PyObject* pyBackPtr = PyList_New(mmax + 1); + + // Convert the result into Python world + for (long m = 0; m <= mmax; ++m) { + PyObject* pyCostTable_m = PyList_New(chainLength + 1); + PyList_SET_ITEM(pyCostTable, m, pyCostTable_m); + PyObject* pyBackPtr_m = PyList_New(chainLength + 1); + PyList_SET_ITEM(pyBackPtr, m, pyBackPtr_m); + for (long i = 0; i <= chainLength; ++i) { + PyObject* pyCostTable_m_i = PyDict_New(); + PyList_SET_ITEM(pyCostTable_m, i, pyCostTable_m_i); + PyObject* pyBackPtr_m_i = PyDict_New(); + PyList_SET_ITEM(pyBackPtr_m, i, pyBackPtr_m_i); + for (long l = i; l <= chainLength; ++l) { + PyObject* pyVar_l = PyLong_FromLong(l); + PyObject* pyCostTable_m_i_l = PyFloat_FromDouble(COST_TABLE(m, i, l)); + PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l); + Py_DECREF(pyCostTable_m_i_l); + PyObject* pyBackPtr_m_i_l; + if (BACK_PTR(m, i, l) < 0) { + pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True); + } else { + pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l)); + } + PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l); + Py_DECREF(pyBackPtr_m_i_l); + Py_DECREF(pyVar_l); + } + } + } + + free(costTable); + free(backPtr); + + PyObject* result = PyTuple_Pack(2, pyCostTable, pyBackPtr); + Py_DECREF(pyCostTable); + Py_DECREF(pyBackPtr); + return result; +} + +static PyMethodDef rotorMethods[] = { + {"compute_table", computeTable, METH_VARARGS, + "Compute the optimal table with the rotor algorithm."}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +static struct PyModuleDef rotorModule = { + PyModuleDef_HEAD_INIT, "rotorc", /* name of module */ + "A simple implementation of dynamic programming algorithm rotor with C in " + "https://hal.inria.fr/hal-02352969. Some code are adapted from " + "https://gitlab.inria.fr/hiepacs/rotor.", /* module documentation, may be + NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + rotorMethods}; + +PyMODINIT_FUNC PyInit_rotorc(void) { return PyModule_Create(&rotorModule); } diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py new file mode 100644 index 0000000000000000000000000000000000000000..21c3bf0da758bd061eaa9bcf08534e9a2df8d6cf --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py @@ -0,0 +1,439 @@ +from copy import deepcopy +from typing import Any, Dict, List, Tuple + +from torch import Tensor +from torch.fx import Graph, Node + +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply +from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions +from colossalai.fx.profiler import ( + activation_size, + calculate_bwd_time, + calculate_fwd_out, + calculate_fwd_time, + calculate_fwd_tmp, +) +from colossalai.logging import get_dist_logger + +from .ckpt_solver_base import CheckpointSolverBase +from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence + +__all__ = ['CheckpointSolverRotor'] + + +class CheckpointSolverRotor(CheckpointSolverBase): + + def __init__(self, + graph: Graph, + free_memory: float = -1, + cnode: List[str] = None, + memory_slots: int = 500, + optim_multiplier: float = 1.0): + """This is the simple implementation of dynamic programming algorithm rotor + in https://hal.inria.fr/hal-02352969. Some code are adapted from + https://gitlab.inria.fr/hiepacs/rotor. + + Usage: + Assume that we have a ``GraphModule``, and we have already done the extractions + to the graph to retrieve all information needed, then we could use the following + code to find a solution using ``CheckpointSolverRotor``: + >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0]) + >>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver + >>> gm.graph = rotor_graph # set the graph to a new graph + + Args: + graph (Graph): The computing graph to be optimized. + free_memory (float, optional): Memory constraint for the solution, unit is byte. + Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1. + cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None. + memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500. + optim_multiplier (float, optional): The multiplier of extra weight storage for the + ``torch.optim.Optimizer``. Default to 1.0. + """ + super().__init__(graph, free_memory, True, cnode, optim_multiplier) + self.memory_slots = memory_slots + + # construct chain + unit = self.free_memory // self.memory_slots + self.chain = self._construct_chain(self.graph, self.node_list) + self.chain.discretize_all(unit) + + self.cost_table = None + self.back_ptr = None + self.sequence = None + + def solve(self, force_python: bool = False, verbose: bool = False) -> Graph: + """Solve the checkpointing problem using rotor algorithm. + + Args: + force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False. + verbose (bool, optional): Print verbose information. Defaults to False. + + Returns: + graph (Graph): The optimized graph, should be a copy of the original graph. + """ + chain = self.chain + + # compute cost table + if force_python: + self.cost_table, self.back_ptr = self._compute_table(chain, self.memory_slots) + else: + self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots) + + if verbose: + self.print_chain() + + # backtrack + try: + self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, + self.back_ptr) + self._annotate_from_sequence(self.sequence, self.node_list) + except ValueError as e: + # using logger to annonce that the solver is failed + logger = get_dist_logger() + logger.warning(f'Checkpoint solver failed: {e}') + raise ValueError + + if verbose: + self.print_sequence() + + return deepcopy(self.graph) + + def print_chain(self): + print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0]) + for idx in range(len(self.node_list) - 1): + print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx], + self.chain.btmp[idx]) + print(f'Chain = {self.chain}') + + def print_sequence(self): + print(f'Sequence = {self.sequence}') + + @classmethod + def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain: + input_tensors = cls._extract_input(graph) + ftime, btime, ftmp, btmp = list(), list(), list(), list() + xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)] + + for node in node_list: + node_info = cls._extract_node_info(node) + ftime.append(node_info[0]) + btime.append(node_info[1]) + x.append(node_info[2]) + xbar.append(node_info[3]) + ftmp.append(node_info[4]) + btmp.append(node_info[5]) + + # currently we view loss backward temp as zero + btime.append(0) + btmp.append(0) + + return Chain(ftime, btime, x, xbar, ftmp, btmp) + + @classmethod + def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]: + """Extract node info from a list of nodes""" + xbar = 0 + ftime = 0 + btime = 0 + fwd_mem_peak = 0 + for n in node: + assert isinstance(n, Node), f'{n} is not a Node' + if n.target == runtime_apply or n.target == runtime_comm_spec_apply: + # in this case we need to calculate memory usage directly based on the statics that hooked in node.meta + xbar += n.meta['fwd_mem_out'] + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp']) + else: + xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n)) + + # minimum flop count is required + ftime += max(calculate_fwd_time(n), 1.0) + btime += max(calculate_bwd_time(n), 1.0) + + x = calculate_fwd_out(node[-1]) + xbar = max(x, xbar) + ftmp = fwd_mem_peak - xbar + btmp = cls._extract_btmp(node) + return ftime, btime, x, xbar, ftmp, btmp + + @staticmethod + def _extract_input(graph: Graph) -> Tuple[Tensor, ...]: + """Extract input tensors from a Graph""" + input_tensors = [] + for node in graph.nodes: + if node.op == 'placeholder': + input_tensors.append(node.meta['fwd_out']) + return input_tensors + + @staticmethod + def _extract_unused_output(node: Node) -> int: + """Extract unused output from `torch.fx.Node`""" + return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node) + + @staticmethod + def _extract_btmp(node: List[Node]) -> int: + """Extract btmp from a list of nodes""" + + def _extract_deps_size(): + deps_size = 0 + for k, v in deps.items(): + k: Node + if v > 0: + deps_size += k.meta['bwd_mem_out'] + if v == float('-inf'): + deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k) + + return deps_size + + btmp = 0 + deps = {} + for n in reversed(node): + deps[n] = len(n.all_input_nodes) + btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp']) + for child in n.users: + if child in deps: + deps[child] -= 1 + if deps[child] <= 0: + deps[child] = float('-inf') # free + return btmp + + @staticmethod + def _compute_table(chain: Chain, mmax: int) -> Tuple: + """Compute the table using dynamic programming. Returns the cost table and the backtracking pointer. + + Args: + chain (Chain): A basic linearized structure for solving the dynamic programming problem. + mmax (int): Maximum number of memory slots. + + Returns: + cost_table (List): cost_table[m][lhs][rhs] indicates the optimal cost of the subproblem from lhs to rhs + with m memory slots. + back_ptr (List): back_ptr[m][lhs][rhs] indicates the best operation at this point. It is (True,) if the optimal choice + is a chain checkpoint, it is (False, j) if the optimal choice is a leaf checkpoint of length j + """ + + ftime = chain.ftime + [0.0] + btime = chain.btime + x = chain.x + [0] + xbar = chain.xbar + [0] + ftmp = chain.ftmp + [0] + btmp = chain.btmp + [0] + + # Build table + cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)] + back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)] + + # Initialize corner cases where length of sequence equals to 1, i.e. lhs == rhs + for m in range(mmax + 1): + for i in range(len(chain) + 1): + limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i]) + if m >= limit: + cost_table[m][i][i] = ftime[i] + btime[i] + else: + cost_table[m][i][i] = float("inf") + + # Compute tables + for m in range(mmax + 1): + for d in range(1, len(chain) + 1): + for i in range(len(chain) + 1 - d): + idx = i + d + mmin = x[idx + 1] + x[i + 1] + ftmp[i] + if idx > i + 1: + mmin = max(mmin, x[idx + 1] + max(x[j] + x[j + 1] + ftmp[j] for j in range(i + 1, idx))) + if m < mmin: + cost_table[m][i][idx] = float("inf") + else: + leaf_checkpoints = [(j, + sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1]) + for j in range(i + 1, idx + 1) + if m >= x[j]] + if leaf_checkpoints: + best_leaf = min(leaf_checkpoints, key=lambda t: t[1]) + else: + best_leaf = None + if m >= xbar[i + 1]: + chain_checkpoint = cost_table[m][i][i] + cost_table[m - xbar[i + 1]][i + 1][idx] + else: + chain_checkpoint = float("inf") + if best_leaf and best_leaf[1] <= chain_checkpoint: + cost_table[m][i][idx] = best_leaf[1] + back_ptr[m][i][idx] = (False, best_leaf[0]) + else: + cost_table[m][i][idx] = chain_checkpoint + back_ptr[m][i][idx] = (True,) + return cost_table, back_ptr + + @staticmethod + def _compute_table_c(chain: Chain, mmax: int) -> Tuple: + try: + from .rotorc import compute_table + + # build module if module not found + except ModuleNotFoundError: + import os + import subprocess + import sys + logger = get_dist_logger() + logger.info("rotorc hasn't been built! Building library...", ranks=[0]) + this_dir = os.path.dirname(os.path.abspath(__file__)) + result = subprocess.Popen( + [ + f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext", + f"--build-lib={this_dir}" + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if result.wait() == 0: + logger.info("rotorc has been built!", ranks=[0]) + from .rotorc import compute_table + else: + logger.warning("rotorc built failed! Using python version!", ranks=[0]) + return CheckpointSolverRotor._compute_table(chain, mmax) + return compute_table(chain, mmax) + + @staticmethod + def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], + back_ptr: List[Any]) -> "Sequence": + """Backtrack the cost table and retrieve the optimal checkpointing strategy. + + Args: + chain (Chain): A basic linearized structure for solving the dynamic programming problem. + lhs (int): The left index of the interval to backtrack. + rhs (int): The right index of the interval to backtrack. + budget (int): The memory budget for processing this interval. + cost_table (List[Any]): See ``._compute_table()`` for definitions + back_ptr (List[Any]): See ``._compute_table()`` for definitions + + Raises: + ValueError: Can not process the chain. + + Returns: + sequence (Sequence): The sequence of executing nodes with checkpoints. + """ + if budget <= 0: + raise ValueError(f"Can not process a chain with negative memory {budget}") + elif cost_table[budget][lhs][rhs] == float("inf"): + raise ValueError(f"Can not process this chain from index {lhs} to {rhs} with memory {budget}") + + sequence = Sequence() + if rhs == lhs: + if lhs == len(chain): + sequence += [Loss()] + else: + sequence += [ForwardEnable(lhs), Backward(lhs)] + return sequence + + if back_ptr[budget][lhs][rhs][0]: + sequence += [ + ForwardEnable(lhs), + CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, + back_ptr), + Backward(lhs), + ] + else: + best_leaf = back_ptr[budget][lhs][rhs][1] + sequence += [ForwardCheck(lhs)] + sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)] + sequence += [ + CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, + back_ptr), + CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr), + ] + return sequence + + @staticmethod + def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): + """Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence. + + Args: + sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations. + node_list (List[List[Node]]): The list of nodes to annotate. + """ + op_list = sequence.list_operations() + loss_op = next(op for op in op_list if isinstance(op, Loss)) + fwd_list = op_list[:op_list.index(loss_op)] + bwd_list = op_list[op_list.index(loss_op) + 1:] + ckpt_idx = 0 + in_ckpt = False + ckpt_region = [] + + # forward annotation + for idx, op in enumerate(fwd_list, 0): + if in_ckpt: + if isinstance(op, ForwardNograd): + ckpt_region.append(idx) + + elif isinstance(op, ForwardEnable): + in_ckpt = False + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'] = [ckpt_idx] + + ckpt_idx += 1 + ckpt_region = [] + + elif isinstance(op, ForwardCheck): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'] = [ckpt_idx] + + ckpt_idx += 1 + ckpt_region = [idx] + + else: + if isinstance(op, ForwardCheck): + in_ckpt = True + ckpt_region.append(idx) + + # annotate the backward if there is any nested activation checkpoint + in_recompute = False + for op in bwd_list: + if in_recompute: + if isinstance(op, ForwardNograd): + ckpt_region.append(op.index) + + elif isinstance(op, ForwardEnable): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'].append(ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [] + + elif isinstance(op, ForwardCheck): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'].append(ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [op.index] + + elif isinstance(op, Backward): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'].append(ckpt_idx) + + in_recompute = False + + else: + if not isinstance(op, Backward): + in_recompute = True + ckpt_idx = 0 + ckpt_region = [] + if isinstance(op, ForwardCheck): + ckpt_region.append(op.index) + + # postprocess, make sure every activation checkpoint label in the + # same activation checkpoint region (level = 0) has the same length + op_list = [] + for node in node_list: + op_list += node + ckpt_regions = _find_nested_ckpt_regions(op_list) + for (start_idx, end_idx) in ckpt_regions: + nested_length = max( + len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1)) + for idx in range(start_idx, end_idx + 1): + op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length - + len(op_list[idx].meta['activation_checkpoint'])) diff --git a/colossalai/auto_parallel/checkpoint/operation.py b/colossalai/auto_parallel/checkpoint/operation.py new file mode 100644 index 0000000000000000000000000000000000000000..ab0c6c5ad38d171d470931aa9b2bdedf6cd17668 --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/operation.py @@ -0,0 +1,184 @@ +import math +from abc import ABC +from typing import Any, Iterable, List + +from torch.utils._pytree import tree_map + + +class Chain: + + def __init__(self, + ftime: List[float], + btime: List[float], + x: List[int], + xbar: List[int], + ftmp: List[int], + btmp: List[int], + check_consistency: bool = True): + """The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint. + See paper https://hal.inria.fr/hal-02352969 for details. + + Args: + ftime (List[float]): The forward time of each node. + btime (List[float]): The backward time of each node. + x (List[int]): The forward memory of each node (if save_output). Same as `a` in the paper. + xbar (List[int]): The forward memory of each node (if save_all). Same as `a_bar` in the paper. + ftmp (List[int]): The temporary forward memory of each node. + btmp (List[int]): The temporary backward memory of each node, can be used to control memory budget. + check_consistency (bool, optional): Check the lengths consistency for the `Chain`. Defaults to True. + """ + self.ftime = ftime + self.btime = btime + self.x = x + self.xbar = xbar + self.ftmp = ftmp + self.btmp = btmp + if check_consistency and not self.check_lengths(): + raise AttributeError("In Chain, input lists do not have consistent lengths") + + def check_lengths(self): + return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1) + and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1) + and (len(self.xbar) == len(self) + 1)) + + def __repr__(self): + chain_list = [] + for i in range(len(self)): + chain_list.append((self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i])) + i = len(self) + chain_list.append((None, self.btime[i], self.x[i], self.xbar[i], None, self.btmp[i])) + return chain_list.__repr__() + + def __len__(self): + return len(self.ftime) + + def discretize_all(self, unit: int): + """Discretize the chain into a list of chains according to unit size.""" + discretizer = lambda val: math.ceil(val / unit) + self.x = tree_map(discretizer, self.x) + self.xbar = tree_map(discretizer, self.xbar) + self.ftmp = tree_map(discretizer, self.ftmp) + self.btmp = tree_map(discretizer, self.btmp) + + +class Operation(ABC): + name = "Op" + + def __repr__(self) -> str: + return f"{self.name}_{self.index}" + + def shift(self, value): + if type(self.index) is tuple: + self.index = tuple(x + value for x in self.index) + else: + self.index += value + + +class Forward(Operation): + name = "F" + + def __init__(self, index): + self.index = index + + def cost(self, chain: Chain): + if chain is not None: + return chain.ftime[self.index] + else: + return 1 + + +class ForwardEnable(Forward): + name = "Fe" + + +class ForwardNograd(Forward): + name = "Fn" + + +class ForwardCheck(Forward): + name = "CF" + + +class Forwards(Operation): + + def __init__(self, start, end): + self.index = (start, end) + + def __repr__(self): + return "F_{i}->{j}".format(i=self.index[0], j=self.index[1]) + + def cost(self, chain: Chain): + if chain is not None: + return sum(chain.ftime[self.index[0]:self.index[1] + 1]) + else: + return (self.index[1] - self.index[0] + 1) + + +def isForward(op): + return type(op) is Forward or type(op) is Forwards + + +class Backward(Operation): + name = "B" + + def __init__(self, index): + self.index = index + + def cost(self, chain: Chain): + if chain is not None: + return chain.btime[self.index] + else: + return 1 + + +class Loss(Operation): + + def __init__(self): + pass + + def __repr__(self): + return "L" + + def cost(self, chain): + return 0 + + +class MemoryAccess(Operation): + name = "MA" + + def __init__(self, index): + self.index = index + + def cost(self, chain: Chain): + return 0 + + +class WriteMemory(MemoryAccess): + name = "WM" + + +class ReadMemory(MemoryAccess): + name = "RM" + + +class DiscardMemory(MemoryAccess): + name = "DM" + + +class Sequence(list): + + def __init__(self): + super().__init__() + + def __repr__(self): + return repr(self.list_operations()) + + def list_operations(self): + op_list = [] + for x in self: + if isinstance(x, Operation): + op_list.append(x) + else: + assert isinstance(x, Sequence) + op_list += x.list_operations() + return op_list diff --git a/colossalai/auto_parallel/meta_profiler/__init__.py b/colossalai/auto_parallel/meta_profiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3741d8e5a8adcf20cb963dda760731911b719c98 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/__init__.py @@ -0,0 +1,3 @@ +from .meta_registry import * +from .registry import meta_register +from .shard_metainfo import * diff --git a/colossalai/auto_parallel/meta_profiler/constants.py b/colossalai/auto_parallel/meta_profiler/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..35b8c13ee8fff717df39a96c60fa101eb0b2a781 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/constants.py @@ -0,0 +1,15 @@ +import operator + +import torch +import torch.nn as nn + +from ..tensor_shard.constants import * + +# list of inplace module +INPLACE_MODULE = [nn.ReLU] + +# list of inplace operations +INPLACE_OPS = [torch.flatten] + +# list of operations that do not save forward activations +NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub] diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d8b656e17e176670cc98755fb4ae93a9832573d --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py @@ -0,0 +1,10 @@ +from .activation import * +from .binary_elementwise_ops import * +from .conv import * +from .embedding import * +from .linear import * +from .non_spmd import * +from .norm import * +from .pooling import * +from .tensor import * +from .where import * diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..0f2e9e44f91cedfdb888171dd373d4d6163f579f --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -0,0 +1,85 @@ +from typing import Callable, List, Tuple + +import torch + +from colossalai._analyzer._subclasses.flop_tensor import ewise_flop_counter as elementwise_flop_counter +from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem + +from ..registry import meta_register + +__all__ = ["elementwise_meta_info"] + + +def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0) -> Callable: + """This is a function to create the meta information generator for elementwise operations + + Args: + temp_mem_scale (float, optional): temp memory scaling factor for backward. Defaults to 0. + buffer_mem_scale (float, optional): buffer memory scaling factor for forward. Defaults to 0. + + Returns: + Callable: meta information generator + """ + + def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + input_tensor = next( + filter( + lambda x: + (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim', + args)).data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + is_inplace = 1 if kwargs.get('inplace', False) else 0 + + flop_counter = elementwise_flop_counter(1, 0) + # calculate compute cost + fwd_compute_cost = flop_counter([input_tensor], [output_tensor]) + bwd_compute_cost = flop_counter([output_tensor], [input_tensor]) + + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + # NOTE: if in_place is True, we will not create a new tensor in forward + fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace), + parameter=0, + temp=0, + buffer=activation_size(input_tensor) * buffer_mem_scale) + + # temp_mem_scale is for situation like softmax backward + # the buffer will be removed during backward phase + bwd_memory_cost = MemoryCost( + activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale, + parameter=0, + temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale, + buffer=0) + + # total cost is the sum of forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + temp=fwd_memory_cost.temp + bwd_memory_cost.temp, + buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [] + fwd_buffer = [torch.zeros_like(output_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out + + return meta_func + + +# register meta information +# (0, 0) +meta_register.register([torch.nn.ReLU, torch.nn.functional.relu, torch.tanh])(elementwise_meta_info(0, 0)) + +# (1, 0) +meta_register.register([torch.nn.Softmax, torch.nn.functional.softmax])(elementwise_meta_info(1, 0)) + +# (0, 0.25) for dropout, the buffer is in bool type so that the buffer memory cost is 0.25 times of input tensor +meta_register.register([torch.nn.Dropout, torch.nn.functional.dropout])(elementwise_meta_info(0, 0.25)) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e451748512b9abebcc4f63ad854be3f129ee52bd --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py @@ -0,0 +1,66 @@ +from typing import List, Tuple + +import torch + +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem + +from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION +from ..registry import meta_register + +__all__ = ['binary_elementwise_meta_info'] + + +@meta_register.register(BCAST_FUNC_OP) +def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """Meta information generator for binary elementwise operations + NOTE: Some of the binary elementwise operations will discard the input activation after computation, as they + don't need those tensors for back propagation, for example, if there are two tensors being sent for `torch.add`, + they will be discarded right after add operation is done. We create a simple API in `ShardMetaInfo` class to identify + this behavior, it is critical for better memory estimation. + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT] + output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args)) + + # construct forward args for flop mapping + fwd_in_args = [opdata.data for opdata in input_op_data] + fwd_out_args = [output_op_data.data] + + # calculate cost + + # calculate compute cost + # NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case + fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args) + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM]) + fwd_mem_cost = MemoryCost( + activation=activation_size(output_op_data.data), + parameter=param_mem_cost, + ) + bwd_mem_cost = MemoryCost( + activation=activation_size(fwd_in_args), + parameter=param_mem_cost, + ) + + # total cost + total_mem_cost = MemoryCost( + activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + ) + + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [] + fwd_buffer = [] + fwd_out = [torch.zeros_like(output_op_data.data, device='meta')] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..4336bf68363c8a708b877c8d8116c44986a85592 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py @@ -0,0 +1,137 @@ +from typing import Callable, Dict, List, Tuple, Union + +import torch + +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.tensor.sharding_spec import ShardingSpec + +from ..registry import meta_register + +__all__ = ['convnd_meta_info'] + + +@meta_register.register(torch.nn.Conv1d) +@meta_register.register(torch.nn.Conv2d) +@meta_register.register(torch.nn.Conv3d) +@meta_register.register(torch.nn.functional.conv1d) +@meta_register.register(torch.nn.functional.conv2d) +@meta_register.register(torch.nn.functional.conv3d) +def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d meta info generator + The atens graph of torch.nn.Convnd with bias is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%input_2, None, None, [None, None, None], [None, None, None], [None, None, None], None, [None, None, None], None), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%convolution_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %convolution_backward_default : [#users=3] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%zeros_like_default, %detach_default, None, [None], [None, None, None], [None, None, None], [None, None, None], None, [None, None, None], None, [None, None, None]), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {}) + %detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {}) + + The atens graph of torch.nn.Convnd without bias is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%input_2, None, None, [None, None], [None, None], [None, None], None, [None, None], None), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%convolution_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %convolution_backward_default : [#users=2] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%zeros_like_default, %detach_default, None, [None], [None, None], [None, None], [None, None], None, [None, None], None, [None, None, None]), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {}) + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + has_bias: bool = False + input_tensor = args[0].data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + if len(args) == 4: + weight_tensors = [args[1].data, args[3].data] + else: + weight_tensors = [args[1].data] + + # check if conv has bias + if len(weight_tensors) > 1: + has_bias = True + # bias tensor's shape only has one dimension + if len(weight_tensors[0].shape) == 1: + bias_tensor, weight_tensor = weight_tensors + else: + weight_tensor, bias_tensor = weight_tensors + + else: + weight_tensor = weight_tensors[0] + + # construct input args for forward + fwd_args = [None] * 9 + + # weight and input + fwd_args[0] = input_tensor + fwd_args[1] = weight_tensor + fwd_args[2] = bias_tensor if has_bias else None + + # transpose indicator should be set to False + fwd_args[6] = False + + # construct input args for backward + bwd_args = [None] * 11 + + # weight and input + bwd_args[0] = output_tensor + bwd_args[1] = input_tensor + bwd_args[2] = weight_tensor + bwd_args[-1] = [True, True, True] if has_bias else [True, True, False] + + # calculate cost + # the fwd op with compute cost is convolution.default + # the bwd op with compute cost is convolution_backward.default + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,)) + bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \ + flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor)) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # TODO: use profiler to check conv temp memory + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) + if has_bias else compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0) + + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) + if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) + if has_bias else compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0) + + # total cost is the sum of forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_buffer = [] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d80f5b3700b6b644c0b630496bd907c0b5aac2 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py @@ -0,0 +1,52 @@ +from typing import List, Tuple + +import torch + +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem + +from ..registry import meta_register + +__all__ = ["embedding_meta_info"] + + +@meta_register.register(torch.nn.Embedding) +def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """torch.nn.Embedding metainfo generator + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data + weight_tensor = next(filter(lambda x: x.type == OperationDataType.PARAM, args)).data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + + # compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor]) + bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor], + [weight_tensor]) + + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # memory cost + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + # NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will + # have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume + # that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=0, + temp=0, + buffer=0) + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0) + + total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_memory_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor)] + fwd_buffer = [] + fwd_out = [torch.zeros_like(output_tensor)] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..7697fc6c383d8154acfe76dba7d8baec225930ac --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py @@ -0,0 +1,406 @@ +from functools import reduce +from typing import Callable, Dict, List, Tuple, Union + +import torch + +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.tensor.sharding_spec import ShardingSpec + +from ..registry import meta_register + +__all__ = ['linear_meta_info', 'matmul_meta_info'] + + +@meta_register.register(torch.nn.functional.linear) +@meta_register.register(torch.nn.Linear) +def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """torch.nn.Linear & torch.nn.functional.linear meta info generator + NOTE: currently we separate the bias part from the biased linear ops, we will consider the memory consumption in add metainfo generator, + but we will hold the bias mechanism in the linear metainfo generator for future use. + + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (None, %input_2, None), kwargs = {}) + %zeros_like_default : [#users=3] = call_function[target=torch.ops.aten.zeros_like.default](args = (%addmm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {}) + %t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {}) + %mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {}) + %t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {}) + %sum_dim_int_list : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%zeros_like_default, [None], None), kwargs = {}) + %view_default : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_dim_int_list, [None]), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%view_default,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default,), kwargs = {}) + %detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {}) + %t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {}) + %detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {}) + %detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {}) + + The one without bias is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%input_2, None), kwargs = {}) + %zeros_like_default : [#users=2] = call_function[target=torch.ops.aten.zeros_like.default](args = (%mm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {}) + %mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {}) + %t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {}) + %mm_default_2 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default_2,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + %t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {}) + %detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {}) + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, bool]: compute cost, memory cost and forward inputs + """ + + has_bias: bool = False + + input_tensor = args[0].data + output_tensor = args[2].data + if len(args) == 4: + weight_tensors = [args[1].data, args[3].data] + else: + weight_tensors = [args[1].data] + + # process the dimension of input and output + if len(input_tensor.shape) > 2: + input_tensor: torch.Tensor + input_tensor = input_tensor.view(-1, input_tensor.shape[-1]) + + if len(output_tensor.shape) > 2: + output_tensor: torch.Tensor + output_tensor = output_tensor.view(-1, output_tensor.shape[-1]) + + if len(weight_tensors) > 1: + has_bias = True + if len(weight_tensors[0].shape) == 2: + weight_tensor, bias_tensor = weight_tensors + else: + bias_tensor, weight_tensor = weight_tensors + else: + weight_tensor = weight_tensors[0] + + if has_bias: + # calculate cost with bias + # the fwd op with compute cost is addmm + # the bwd op with compute cost is mm * 2 and sum.dim_IntList + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default]( + [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)) + bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \ + flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \ + flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,)) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # NOTE: Linear don't have buffer and temp in forward and backward phase + # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=0) + + # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0 + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=0) + + # total cost is to sum the forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + else: + # calculate cost without bias + # the fwd op with compute cost is mm + # the bwd op with compute cost is mm * 2 + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( + [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)) + bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \ + flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # NOTE: Linear don't have buffer and temp in forward and backward phase + # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0) + + # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0 + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]), + parameter=compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0) + + # total cost is to sum the forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_buffer = [] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out + + +@meta_register.register(torch.matmul) +def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """torch.matmul meta info generator + There are several cases for torch.matmul: + 1. Vector-vector multiplication => no temp memory, forward memory cost is 1 element (could be neglected), backward memory cost is the same + as two input vectors. + 2. Matrix-vector multiplication => if the first input is matrix, no temp memory is needed, otherwise, there is a temp memory in the backward + phase for the transpose of the matrix. The forward memory cost is the size of output tensor, backward memory cost is the size of the two inputs; if + the first input is vector, the forward memory cost is the size of the output tensor, and during the backward phase, it will allocate a temp memory + the same size as the input matrix, and allocate memory for the gradient of two inputs. + 3. Batched Matrix-vector multiplication => if the first input is the batched matrix, no temp memory, the forward memory cost is the size of + output tensor, backward memory cost is the size of the two inputs; if the second input is the batched matrix, the matmul will allocate memory for + the gradient of the batched matrix in the forward phase (as they create a new tensor without the former batches), so the forward memory cost is + the output tensor and the newly created matrix (take the same amount of memory of the input batched matrix). During the backward phase, it will + allocate a temp memory the same size as input batched matrix, and allocate a tensor for the gradient of the input vector. The gradient of the batched + matrix will be stored in the memory allocated during the forward phase. + 3. Matrix-matrix multiplication => no temp memory, forward memory is the size of output tensor, backward memory is the size of the two inputs + 4. Batched matrix-matrix multiplication => if the first input is the batched matrix, no temp memory, the forward memory cost is the size of two + inputs and backward memory cost is the size of the output tensor; if the second input is the batched matrix, during the forward phase it will allocate + memory for the output and gradient of the second input, and has a temp memory the same size as the output, during the backward phase, it + will allocate memory for the gradient of the first input and has a temp memory which is as big as output and the second input. + 5. Batched matrix-batched matrix multiplication => if the two inputs have the same batch dimensions, no temp memory, the forward memory cost is the size + of output, backward memory cost is the size of the two inputs; it the two inputs have different batch dimensions, during the forward phase it will allocate + memory of the expanded inputs (so that the batch dimensions could match) and the output, and during the backward phase, it has a temp memory of the size of + two expanded inputs, and it will allocate memory for the gradient of the two inputs and discard the expanded inputs allocated during the forward phase. + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, bool]: compute cost, memory cost and forward inputs + + """ + # Get input and output tensors + input_tensors = [args[0].data, args[1].data] + output_tensors = [args[-1].data] + + # Check dimension + if all(len(tensor.shape) == 1 for tensor in input_tensors): + # Dot + fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors) + bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2 + + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) + + elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1: + # gemv case 1: matrix-vector multiplication + # & + # batched gemv case 1: batched matrix-vector multiplication + + fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default]( + [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors) + + # combine the dimensions of output + bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( + [output_tensors[0].reshape(-1), input_tensors[1]], + output_tensors) + \ + flop_mapping[torch.ops.aten.matmul.default]( + [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)], + output_tensors) + + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) + + elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2: + # gemv case 2: vector-matrix multiplication + fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors) + + bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \ + flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors) + + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), + parameter=0, + temp=compute_size_in_bytes(input_tensors[1]), + buffer=0) + + elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3: + # batched gemv case 2: vector-batched matrix multiplication + + fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default]( + [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]], + [output_tensors[0].reshape(-1)]) + + # combine the dimensions of output + bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( + [output_tensors[0].reshape(-1), input_tensors[0]], + output_tensors + ) + \ + flop_mapping[torch.ops.aten.matmul.default]( + [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)], + output_tensors + ) + + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]])) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]), + parameter=0, + temp=compute_size_in_bytes(input_tensors[1]), + buffer=0) + + elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2: + # gemm & batched gemm case 1: batched matrix-matrix multiplication + + fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( + [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], + [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])]) + + bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( + [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])], + [input_tensors[1]] + ) + \ + flop_mapping[torch.ops.aten.mm.default]( + [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)], + [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])] + ) + + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) + + elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3: + # batched gemm case 2: matrix-batched matrix multiplication + fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([ + input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose( + 0, 1) + ], [output_tensors[0].transpose(-2, -1)]) + + bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( + [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])], + [input_tensors[0]] + ) + \ + flop_mapping[torch.ops.aten.mm.default]( + [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]], + [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])] + ) + + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) + + compute_size_in_bytes(input_tensors[1]), + temp=compute_size_in_bytes(output_tensors)) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]), + parameter=0, + temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors)) + + elif all(len(tensor.shape) >= 3 for tensor in input_tensors): + # Batched matrix-batched matrix multiplication + # Fetch shape of the two inputs and see if the batch dimensions are the same + _is_batch_dims_same = True + if len(input_tensors[0].shape) == len(input_tensors[1].shape): + for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]): + if shape_0 != shape_1: + _is_batch_dims_same = False + break + else: + _is_batch_dims_same = False + + # retireve dimensions + input_dim_00 = input_tensors[0].shape[-2] + input_dim_01 = input_tensors[0].shape[-1] + input_dim_10 = input_tensors[1].shape[-2] + input_dim_11 = input_tensors[1].shape[-1] + output_dim_0 = output_tensors[0].shape[-2] + output_dim_1 = output_tensors[0].shape[-1] + + if _is_batch_dims_same: + # Case 1: batch dimensions are the same + + # Forward compute cost: C = A * B + fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([ + input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape( + -1, input_dim_10, input_dim_11) + ], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]) + + # Backward compute cost: dB = A^T * dC, dA = dC * B^T + bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( + [input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], + [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)] + ) + \ + flop_mapping[torch.ops.aten.bmm.default]( + [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)], + [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)] + ) + + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors)) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors)) + + else: + # Case 2: batch dimensions are different + batch_dims = output_tensors[0].shape[:-2] + extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims), + input_dim_00, + input_dim_01, + device="meta") + extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims), + input_dim_10, + input_dim_11, + device="meta") + + # Forward compute cost: C = A * B + fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( + [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]) + + # Backward compute cost: dB = A^T * dC, dA = dC * B^T + bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( + [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], + [extended_input_1] + ) + \ + flop_mapping[torch.ops.aten.bmm.default]( + [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)], + [extended_input_0] + ) + + fwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) - + compute_size_in_bytes([extended_input_0, extended_input_1]), + temp=compute_size_in_bytes([extended_input_0, extended_input_1])) + + # compute cost + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # memory cost + total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + temp=fwd_mem_cost.temp + bwd_mem_cost.temp, + buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) + + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = input_tensors + fwd_buffer = [] + fwd_out = output_tensors + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py new file mode 100644 index 0000000000000000000000000000000000000000..12874810b13e252c0597e2adf124ab7875e992a3 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py @@ -0,0 +1,27 @@ +import operator +from typing import List, Tuple + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem + +from ..registry import meta_register + +__all__ = ["non_spmd_meta_info"] + + +@meta_register.register(torch.Size) +@meta_register.register(torch.Tensor.size) +@meta_register.register(torch.finfo) +@meta_register.register(operator.le) +def non_spmd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """Non-SPMD node meta information generator + Those nodes will not be handled by SPMD solver, so we just return all zero meta information for it + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + compute_cost = TrainCycleItem(fwd=0, bwd=0, total=0) + memory_cost = TrainCycleItem(fwd=MemoryCost(), bwd=MemoryCost(), total=MemoryCost()) + fwd_in, fwd_buffer, fwd_out = [], [], [] + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..b872fdc8bdcd19717e7b81d436fffd860ec88519 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py @@ -0,0 +1,158 @@ +from typing import Callable, Dict, List, Tuple, Union + +import torch + +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.tensor.sharding_spec import ShardingSpec + +from ..registry import meta_register + +__all__ = ['batchnormnd_meta_info', 'layernorm_meta_info'] + + +@meta_register.register(torch.nn.BatchNorm1d) +@meta_register.register(torch.nn.BatchNorm2d) +@meta_register.register(torch.nn.BatchNorm3d) +def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """BatchNorm1d, BatchNorm2d, BatchNorm3d, meta info generator + The aten graph of BatchNorm2d is like + + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %cudnn_batch_norm_default : [#users=4] = call_function[target=torch.ops.aten.cudnn_batch_norm.default](args = (%input_2, None, None, None, None, None, None, None), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%cudnn_batch_norm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {}) + %detach_default_2 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {}) + %cudnn_batch_norm_backward_default : [#users=3] = call_function[target=torch.ops.aten.cudnn_batch_norm_backward.default](args = (%detach_default, %zeros_like_default, None, None, None, %detach_default_1, %detach_default_2, None, %detach_default_3), kwargs = {}) + %detach_default_4 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {}) + %detach_default_5 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_4,), kwargs = {}) + %detach_default_6 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {}) + %detach_default_7 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_6,), kwargs = {}) + %detach_default_8 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {}) + %detach_default_9 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_8,), kwargs = {}) + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + input_tensor = args[0].data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + weight_tensor = next(filter(lambda x: x.name == "weight", args)).data + bias_tensor = next(filter(lambda x: x.name == "bias", args)).data + mean_tensor = next(filter(lambda x: x.name == "running_mean", args)).data + var_tensor = next(filter(lambda x: x.name == "running_var", args)).data + num_batch = next(filter(lambda x: x.name == "num_batches_tracked", args)).data + + # construct fwd args + # the fwd inputs are input, weight, bias, running_mean, running_var and some other args + # indicating the status of the module + # the fwd outputs are output, saved mean, saved inv std and num batches tracked + fwd_in_args = [input_tensor, weight_tensor, bias_tensor, mean_tensor, var_tensor, True, 0.1, 1e-5] + fwd_out_args = [output_tensor, mean_tensor, var_tensor, num_batch] + + # construct bwd args + # the bwd inputs are upstream grad, input, weight, running_mean, running_var, saved mean, + # saved inv std and some other args indicating the status of the module + # the bwd outputs are input grad, weight grad and bias grad + bwd_in_args = [ + output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch + ] + bwd_out_args = [input_tensor, weight_tensor, bias_tensor] + + # calculate cost + fwd_compute_cost = flop_mapping[torch.ops.aten.cudnn_batch_norm.default](fwd_in_args, fwd_out_args) + bwd_compute_cost = flop_mapping[torch.ops.aten.cudnn_batch_norm_backward.default](bwd_in_args, bwd_out_args) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # the fwd activation cost is output plus saved mean and saved inv std + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes( + [input_tensor, output_tensor, mean_tensor, var_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=compute_size_in_bytes([mean_tensor, var_tensor])) + + # the bwd memory cost is quite tricky here, BatchNorm will remove saved mean + # and saved inv std during backward phase + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=compute_size_in_bytes([mean_tensor, var_tensor]), + buffer=compute_size_in_bytes([mean_tensor, var_tensor])) + + # total cost is the sum of forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out + + +@meta_register.register(torch.nn.LayerNorm) +def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """LayerNorm meta information + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + # construct needed tensors + input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + weight_tensor = next(filter(lambda x: x.name == "weight", args)).data + bias_tensor = next(filter(lambda x: x.name == "bias", args)).data + running_mean = torch.rand(input_tensor.shape[0], 1, device='meta') + running_var = torch.rand(input_tensor.shape[0], 1, device='meta') + + # construct args + fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor] + fwd_out_args = [output_tensor] + bwd_in_args = [input_tensor, output_tensor, [input_tensor.shape[0]]] + bwd_out_args = [weight_tensor, bias_tensor] + + # compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.native_layer_norm.default](fwd_in_args, fwd_out_args) + bwd_compute_cost = flop_mapping[torch.ops.aten.native_layer_norm_backward.default](bwd_in_args, bwd_out_args) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # memory cost + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes( + [input_tensor, output_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=compute_size_in_bytes([running_mean, running_var])) + + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=compute_size_in_bytes([running_mean, running_var]), + buffer=compute_size_in_bytes([running_mean, running_var])) + + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + temp=fwd_memory_cost.temp + bwd_memory_cost.temp, + buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..d785dfcca9bacb46e129adda5f83486090975859 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py @@ -0,0 +1,133 @@ +from typing import List, Tuple + +import torch + +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem + +from ..registry import meta_register + +__all__ = ["avgpool_meta_info", "maxpool_meta_info"] + + +@meta_register.register(torch.nn.AdaptiveAvgPool1d) +@meta_register.register(torch.nn.AdaptiveAvgPool2d) +@meta_register.register(torch.nn.AdaptiveAvgPool3d) +def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """Meta info for AdaptiveAvgPool + The aten graph of AdaptiveAvgPool is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %_adaptive_avg_pool2d_default : [#users=1] = call_function[target=torch.ops.aten._adaptive_avg_pool2d.default](args = (%input_2, [None, None]), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%_adaptive_avg_pool2d_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %_adaptive_avg_pool2d_backward_default : [#users=1] = call_function[target=torch.ops.aten._adaptive_avg_pool2d_backward.default](args = (%zeros_like_default, %detach_default), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%_adaptive_avg_pool2d_backward_default,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + input_tensor = args[0].data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + is_inplace = kwargs.get("inplace", False) + + # construct forward args for flop mapping + fwd_in_args = [input_tensor] + fwd_out_args = [output_tensor] + + # construct backward args for flop mapping + bwd_in_args = [output_tensor] + bwd_out_args = [input_tensor] + + # calculate cost + # the fwd op with compute cost is _adaptive_avg_pool2d.default + # the bwd op with compute cost is _adaptive_avg_pool2d_backward.default + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args) + bwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d_backward.default](bwd_in_args, bwd_out_args) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(output_tensor)) + bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(input_tensor)) + + # total cost + total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation) + + mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [] + fwd_buffer = [] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] + + return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out + + +@meta_register.register(torch.nn.MaxPool1d) +@meta_register.register(torch.nn.MaxPool2d) +@meta_register.register(torch.nn.MaxPool3d) +def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """Meta info for MaxPool + The aten graph of MaxPool is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %max_pool2d_with_indices_default : [#users=2] = call_function[target=torch.ops.aten.max_pool2d_with_indices.default](args = (%input_2, [None, None], [None, None]), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%max_pool2d_with_indices_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%max_pool2d_with_indices_default,), kwargs = {}) + %max_pool2d_with_indices_backward_default : [#users=1] = call_function[target=torch.ops.aten.max_pool2d_with_indices_backward.default](args = (%zeros_like_default, %detach_default, [None, None], [None, None], [None, None], [None, None], None, %detach_default_1), kwargs = {}) + %detach_default_2 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%max_pool2d_with_indices_backward_default,), kwargs = {}) + %detach_default_3 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_2,), kwargs = {}) + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + + # construct forward args for flop mapping + fwd_in_args = [input_tensor] + fwd_out_args = [output_tensor] + + # construct backward args for flop mapping + bwd_in_args = [output_tensor] + bwd_out_args = [input_tensor] + + # construct index matrix + index_matrix = torch.zeros_like(output_tensor, device="meta", dtype=torch.int64) + + # calculate cost + # the fwd op with compute cost is max_pool2d_with_indices.default + # the bwd op with compute cost is max_pool2d_with_indices_backward.default + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.max_pool2d_with_indices.default](fwd_in_args, fwd_out_args) + bwd_compute_cost = flop_mapping[torch.ops.aten.max_pool2d_with_indices_backward.default](bwd_in_args, bwd_out_args) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # NOTE: the index matrix will be discarded in backward phase + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix])) + + # temp memory for backward is the index matrix to be discarded + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix), + temp=compute_size_in_bytes(index_matrix)) + + # total cost + total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp) + + mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_buffer = [torch.zeros_like(index_matrix, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] + + return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..97fe3c6196f591af7bbfcbdcf59ff3afd114175f --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py @@ -0,0 +1,79 @@ +from typing import Callable, List, Tuple + +import torch + +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem + +from ..registry import meta_register + +__all__ = ["tensor_related_metainfo"] + + +def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: float = 0) -> Callable: + """torch.Tensor related metainfo generator template + + Args: + bwd_mem_out_factor (float, optional): backward activation memory cost factor. Defaults to 1. + bwd_mem_tmp_factor (float, optional): backward temp memory cost factor. Defaults to 0. + + Returns: + Callable: torch.Tensor related metainfo generator + """ + + def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """torch.Tensor related metainfo generator + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + outputs = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + + # compute costs are all zero + compute_cost = TrainCycleItem(fwd=0, bwd=0, total=0) + + # memory costs + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0) + + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor, + parameter=0, + temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor, + buffer=0) + + total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + temp=fwd_mem_cost.temp + bwd_mem_cost.temp, + buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) + + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [] + fwd_buffer = [] + if isinstance(outputs, tuple) or isinstance(outputs, list) or isinstance(outputs, dict): + # tuple of tensors + fwd_out = [torch.zeros_like(tensor) for tensor in outputs] + else: + # enaged_tensors is a single tensor + fwd_out = [torch.zeros_like(outputs)] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out + + return meta_func + + +# register torch.Tensor related metainfo +# (0, 0) +meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, + torch.arange])(tensor_related_metainfo(0, 0)) + +# (1, 0) +meta_register.register([ + torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute, + torch.Tensor.split, torch.split, torch.Tensor.view +])(tensor_related_metainfo(1, 0)) + +# (1, 1) +meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1)) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py new file mode 100644 index 0000000000000000000000000000000000000000..5cba1b5b6e2b16521ed2a0df2fbab98b19492c53 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py @@ -0,0 +1,60 @@ +from typing import List, Tuple + +import torch + +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem + +from ..registry import meta_register + +__all__ = ["where_meta_info"] + + +@meta_register.register(torch.where) +def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """torch.where meta information generator + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + condition_tensor, x_tensor, y_tensor, output_tensor = [arg.data for arg in args] + + # compute cost + fwd_compute_cost = 0 + + # if we need to broadcast the condition tensor, during backward we need to do a reduce_sum + bwd_compute_cost = 0 + if x_tensor.shape != output_tensor.shape: + bwd_compute_cost += flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], [x_tensor]) + if y_tensor.shape != output_tensor.shape: + bwd_compute_cost += flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], [y_tensor]) + + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # memory cost + # during the forward phase, torch.where will allocate memory for output tensor and condition tensor + # during the backward phase, torch.where will allocate temp memory which is 3 times as output tensor, then generate + # gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase + # NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward + fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor])) + bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]), + parameter=0, + temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) - + activation_size([x_tensor, y_tensor]), + buffer=0) + + total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + temp=fwd_mem_cost.temp + bwd_mem_cost.temp, + buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) + + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [condition_tensor] + fwd_buffer = [] + fwd_out = [output_tensor] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/registry.py b/colossalai/auto_parallel/meta_profiler/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..46350c4dd406691c344eb92a933636d6b029b8bd --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/registry.py @@ -0,0 +1,32 @@ +__all__ = ['Registry'] + + +class Registry: + + def __init__(self, name): + self.name = name + self.store = {} + + def register(self, source): + + def wrapper(func): + if isinstance(source, (list, tuple)): + # support register a list of items for this func + for element in source: + self.store[element] = func + else: + self.store[source] = func + return func + + return wrapper + + def get(self, source): + assert source in self.store, f'{source} not found in the {self.name} registry' + target = self.store[source] + return target + + def has(self, source): + return source in self.store + + +meta_register = Registry('meta') diff --git a/colossalai/auto_parallel/meta_profiler/shard_metainfo.py b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..0eee908b48b73d9d1cd5e0e35fbf50b8d844e3a6 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py @@ -0,0 +1,131 @@ +from typing import Callable, List + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.tensor.sharding_spec import ShardingSpec + +from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION +from .registry import meta_register + +__all__ = ['ShardMetaInfo'] + + +class ShardMetaInfo: + """ShardMetaInfo class + This class is used to store meta info based on sharding strategy and the given + target function. + """ + + def __init__(self, strategy: ShardingStrategy = None, target: Callable = None) -> None: + # compute cost of forward and backward computation + self.compute_cost: TrainCycleItem + + # compute memory cost of forward and backward phase + self.memory_cost: TrainCycleItem + + # list of input tensors + self.fwd_in: List[torch.Tensor] + + # list of buffer tensors + self.fwd_buffer: List[torch.Tensor] + + # list of output tensors + self.fwd_out: List[torch.Tensor] + + # sharding strategy + self._strategy = strategy + + # target function + self._target = target + + # compute shard_metainfo if possible + if self._strategy is not None and self._target is not None: + self.compute_shard_metainfo() + + @property + def strategy(self) -> ShardingStrategy: + return self._strategy + + @property + def target(self) -> Callable: + return self._target + + @strategy.setter + def strategy(self, strategy: ShardingStrategy) -> None: + self._strategy = strategy + if self._strategy is not None and self._target is not None: + self.compute_shard_metainfo() + + @target.setter + def target(self, target: Callable) -> None: + self._target = target + if self._strategy is not None and self._target is not None: + self.compute_shard_metainfo() + + def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec): + """ + Compute sharded opdata based on the given data and sharding spec. + """ + + if isinstance(sharding_spec, ShardingSpec): + op_data = OperationData(name=operation_data.name, + data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"), + type=operation_data.type, + logical_shape=operation_data.logical_shape) + elif isinstance(sharding_spec, (list, tuple)): + data = operation_data.data + assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}." + assert len(data) == len(sharding_spec), f"Length of data and sharding spec should be the same." + sharded_data = [] + for d, s in zip(data, sharding_spec): + sharded_data.append(torch.zeros(s.get_sharded_shape_per_device(), device="meta")) + op_data = OperationData(name=operation_data.name, data=sharded_data, type=operation_data.type) + else: + raise ValueError(f"Sharding spec should be ShardingSpec or list, but got {type(sharding_spec)}.") + + return op_data + + def compute_shard_metainfo(self): + """ + Compute meta info based on sharding strategy and the given target function. + """ + assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \ + f"Meta info for {self._target} is not registered." + if meta_register.has(self._target.__class__): + # module + meta_func = meta_register.get(self._target.__class__) + + # check whether the target in the list that we don't need to save activation + save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION + else: + # function + meta_func = meta_register.get(self._target) + + # check whether the target in the list that we don't need to save activation + save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION + + # construct args for meta_func + args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()] + + # construct kwargs + if self.target in INPLACE_MODULE: + kwargs = {'inplace': self.target.inplace} + elif self.target in INPLACE_OPS: + kwargs = {'inplace': True} + else: + kwargs = {'inplace': False} + + # compute metainfo with meta_func + self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs) + + # process corner case for NO_SAVE_ACTIVATION + if not save_fwd_in: + self.fwd_in = [] diff --git a/colossalai/auto_parallel/offload/__init__.py b/colossalai/auto_parallel/offload/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a79e5006e7d2ac264f00b5a597e7b869f6f580eb --- /dev/null +++ b/colossalai/auto_parallel/offload/amp_optimizer.py @@ -0,0 +1,177 @@ +from typing import Dict, Tuple +from enum import Enum +import torch +from torch.optim import Optimizer + +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.utils import get_current_device + +from .base_offload_module import BaseOffloadModule +from .region_manager import RegionManager +from .region import Region + + +class OptimState(Enum): + SCALED = 0 + UNSCALED = 1 + +class AMPOptimizer(ColossalaiOptimizer): + + """ + A wrapper for Optimizer. + Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py + + Args: + optimizer (Optimizer): An Optimizer instance. + module (BaseOffloadModule): A ``BaseOffloadModule`` instance. + initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16. + growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. + backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. + growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000. + hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2. + min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. + max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32. + norm_type (float, optional): norm_type used for `clip_grad_norm`. + """ + + def __init__(self, + optimizer: Optimizer, + module: BaseOffloadModule, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + clipping_norm: float = 0.0, + norm_type: float = 2.0): + + super().__init__(optimizer) + + self.module = module + self.optim_state = OptimState.UNSCALED + self.clipping_flag = clipping_norm > 0.0 + self.max_norm = clipping_norm + + self.region_manager: RegionManager = self.module.region_manager + self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict() + self.param_to_region: Dict[torch.nn.Parameter, Region] = dict() + + self.fp32_to_fp16_params: Dict[torch.Tensor, torch.nn.Parameter] = dict() + + if self.clipping_flag: + assert norm_type == 2.0, "AMPOptimizer only supports L2 norm now" + + self.__init__optimizer() + + # Grad scaler + self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + self._logger = get_dist_logger() + + def _set_grad_ptr(self): + for group in self.param_groups: + for fake_param in group['params']: + region = self.param_to_region[fake_param] + begin, end = self.param_to_range[fake_param] + + fake_param.data = region.cpu_grad[begin:end] + fake_param.grad = fake_param.data + fake_param.data = region.fp32_data[begin:end] + + def _update_fp16_params(self): + none_tensor = torch.empty([0]) + for group in self.param_groups: + for fake_param in group['params']: + assert fake_param.grad is None + fake_param.data = none_tensor + self.param_to_region[fake_param].cpu_grad = None + + def _check_overflow(self): + # clear previous overflow record + self._found_overflow.fill_(self.module.overflow_counter.item()) + return self._found_overflow.item() > 0 + + def _get_combined_scale(self): + loss_scale = 1 + + if self.optim_state == OptimState.SCALED: + loss_scale = self.loss_scale + self.optim_state = OptimState.UNSCALED + + combined_scale = loss_scale + + if combined_scale == 1: + return -1 + else: + return combined_scale + + @property + def loss_scale(self): + return self.grad_scaler.scale.item() + + def zero_grad(self, *args, **kwargs): + self.module.overflow_counter = torch.cuda.IntTensor([0]) + return self.optim.zero_grad(set_to_none=True) + + def step(self, *args, **kwargs): + # Copy gradients from model params to main params. + self._set_grad_ptr() + + found_inf = self._check_overflow() + if found_inf: + self.optim_state = OptimState.UNSCALED # no need to unscale grad + self.grad_scaler.update(found_inf) # update gradient scaler + self._logger.info(f'Found overflow. Skip step') + self.zero_grad() # reset all gradients + self._update_fp16_params() + return + + # get combined scale. combined scale = loss scale * clipping norm + # so that gradient = gradient / combined scale + combined_scale = self._get_combined_scale() + self.grad_scaler.update(found_inf) + + ret = self.optim.step(div_scale=combined_scale, *args, **kwargs) + self.zero_grad() + self._update_fp16_params() + return ret + + def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0): + raise NotImplementedError + + def backward(self, loss: torch.Tensor): + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + self.module.backward(loss) + + def __init__optimizer(self): + + for group in self.optim.param_groups: + fake_params_list = list() + + for param in group['params']: + region = self.region_manager.get_region(param) + fake_param = torch.nn.Parameter(torch.empty([0])) + self.param_to_range[fake_param] = region.param_to_range[param] + self.param_to_region[fake_param] = region + fake_params_list.append(fake_param) + + # Reset existing state dict key to the new main param. + if param in self.optim.state: + self.optim.state[fake_param] = self.optim.state.pop(param) + + group['params'] = fake_params_list + + # Leverage state_dict() and load_state_dict() to + # recast preexisting per-param state tensors + self.optim.load_state_dict(self.optim.state_dict()) \ No newline at end of file diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py new file mode 100644 index 0000000000000000000000000000000000000000..d0c328e134ff5696ea2f8c17d5fc3468e4c891a2 --- /dev/null +++ b/colossalai/auto_parallel/offload/base_offload_module.py @@ -0,0 +1,107 @@ +from functools import partial +from typing import Optional, Set + +import torch +import torch.nn as nn + +from colossalai.nn.parallel.data_parallel import _cast_float +from colossalai.zero.legacy.gemini.tensor_utils import free_storage + +from .region_manager import RegionManager +from .util import GlobalRuntimeInfo + + +class BaseOffloadModule: + """ + BaseOffloadModule: A model wrapper for parameter offloading. + + Args: + model (nn.Module): model to apply offloading. + region_manager (RegionManager): a ``RegionManager`` instance. + is_sync (bool): synchronous mode or not. + """ + + def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True): + + self.model = model + self.region_manager = region_manager + self.grad_hook_list = [] + self.overflow_counter = torch.cuda.IntTensor([0]) + + self.grad_offload_stream = torch.cuda.current_stream() if is_sync else GlobalRuntimeInfo.d2h_stream + + self._cast_buffers() + + def register_grad_hook(self): + for p in self.model.parameters(): + if p.requires_grad: + self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p))) + + def remove_grad_hook(self): + for hook in self.grad_hook_list: + hook.remove() + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def _pre_forward(self): + self.register_grad_hook() + for region in self.region_manager.region_list: + region.cpu_grad = None + + def forward(self, *args, **kwargs): + args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) + self.model.zero_grad(set_to_none=True) + self._pre_forward() + outputs = self.model(*args, **kwargs) + return outputs + + def backward(self, loss): + loss.backward() + self._post_backward() + + def _post_backward(self): + torch.cuda.synchronize() + self.remove_grad_hook() + + for p in self.model.parameters(): + p.grad = None + + GlobalRuntimeInfo().fwd_prefetch_event_map.clear() + GlobalRuntimeInfo().bwd_prefetch_event_map.clear() + + def grad_handle(self, p, grad): + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + with torch._C.DisableTorchFunction(): + region = self.region_manager.get_region(p) + region.copy_grad_to_region_slice(p, grad) + if region.can_release: + self.overflow_counter += region.has_inf_or_nan + master_stream = torch.cuda.current_stream() + with torch.cuda.stream(self.grad_offload_stream): + GlobalRuntimeInfo().d2h_stream.wait_stream(master_stream) + region.move_grad_to_cpu() + return empty_grad + + def _cast_buffers(self): + for buffer in self.model.buffers(): + buffer.data = buffer.cuda() + + def parameters(self, recurse: bool = True): + return self.model.parameters(recurse) + + def named_parameters(self, prefix: str = '', recurse: bool = True): + return self.model.named_parameters(prefix, recurse) + + def named_buffers(self, prefix: str = '', recurse: bool = True): + return self.model.named_buffers(prefix, recurse) + + def named_children(self): + return self.model.named_children() + + def named_modules(self, + memo: Optional[Set[torch.nn.Module]] = None, + prefix: str = '', + remove_duplicate: bool = True): + return self.model.named_modules(memo, prefix, remove_duplicate) diff --git a/colossalai/auto_parallel/offload/mem_optimize.py b/colossalai/auto_parallel/offload/mem_optimize.py new file mode 100644 index 0000000000000000000000000000000000000000..d56166dea982288bdea160e1347c8ca3f67ed297 --- /dev/null +++ b/colossalai/auto_parallel/offload/mem_optimize.py @@ -0,0 +1,52 @@ +from typing import Dict + +import torch +import torch.fx +from torch.fx import GraphModule +from torch.utils._pytree import tree_map + +from colossalai.fx import ColoTracer, is_compatible_with_meta +from colossalai.fx.passes.meta_info_prop import MetaInfoProp + +from .base_offload_module import BaseOffloadModule +from .region_manager import RegionManager +from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_pass +from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem + + +def memory_optimize(model: torch.nn.Module, + inps: Dict[str, torch.Tensor], + memory_budget: float = -1.0, + solver_name: str = 'asyn'): + + model = model.cpu().half() + tracer = ColoTracer() + assert is_compatible_with_meta() + wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x + meta_args = tree_map(wrap_fn, inps) + graph = tracer.trace(model, meta_args=meta_args) + gm = GraphModule(model, graph, model.__class__.__name__) + interp = MetaInfoProp(gm) + interp.propagate(*meta_args.values()) + + region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget) + region_manager._build_regions() + GlobalRuntimeInfo().region_list = region_manager.region_list + + act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024**2 + max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024**2 + total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024**2 + print( + f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}" + ) + + if solver_name == 'syn': + gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list) + elif solver_name == 'asyn': + gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list) + else: + raise TypeError(f"Unknown solver name {solver_name}!") + + gm.recompile() + optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn') + return optimized_model diff --git a/colossalai/auto_parallel/offload/region.py b/colossalai/auto_parallel/offload/region.py new file mode 100644 index 0000000000000000000000000000000000000000..819ffbd96eb19098f168519ca6e3e0036fa3a638 --- /dev/null +++ b/colossalai/auto_parallel/offload/region.py @@ -0,0 +1,145 @@ +from typing import Dict, List, Tuple + +import torch +from torch.fx import Node + +from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage + + +class Region: + """ + Region: A container owning a piece of contiguous nodes in the DNN computing graph. + + Args: + r_id (int): the index of the region in the computing graph. + """ + + def __init__(self, r_id: int = 0) -> None: + self.r_id: int = r_id + self.fp16_params: List[torch.nn.Parameter] = [] + self.param_size: int = 0 + self.shared_rid: int = self.r_id + + self.param_num: int = 0 + self.grad_num: int = 0 + self.fp16_data = None + self.fp32_data = None + self.cpu_grad = None + self.temp_fp32_data = None + self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict() + + self.need_offload: bool = False + self.is_syn: bool = False + self.nodes: List[Node] = [] + self.fwd_prefetch_region = None + self.bwd_prefetch_region = None + + self.in_mem_pool_flag: bool = False + + @property + def can_release(self) -> bool: + """ + Check if the region can be released. + """ + return self.grad_num == self.param_num + + @property + def has_inf_or_nan(self) -> bool: + """ + Check if the grad of the region has inf or nan values on CUDA. + """ + return torch.isinf(self.fp16_data).any() | torch.isnan(self.fp16_data).any() + + def init_param_data(self, pre_alloc_tensor: torch.Tensor = None): + """ + Map the parameters in the region to a contiguous memory space. + """ + + self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda') + offset = 0 + for param in self.fp16_params: + param.data = param.data.cuda() + p_num = param.data.numel() + self.fp16_data[offset:offset + p_num].copy_(param.data.flatten()) + param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape) + self.param_to_range[param] = (offset, offset + p_num) + offset += p_num + + self.fp32_data = self.fp16_data.float().cpu().pin_memory() + free_storage(self.fp16_data) + if self.in_mem_pool_flag and pre_alloc_tensor is not None: + self.fp16_data = pre_alloc_tensor + + def move_param_to_cuda(self): + """ + Move parameters from CPU to GPU. + It first moves float32 parameters to GPU and + then transforms float32 parameters to half-precision on the GPU. + The reason is that the performance of precision conversion on the CPU + is much slower than the data transfer overhead. + """ + + self.temp_fp32_data.copy_(self.fp32_data, non_blocking=True) + self.temp_fp32_data.record_stream(torch.cuda.current_stream()) + if not self.in_mem_pool_flag: + alloc_storage(self.fp16_data) + self.fp16_data[:self.param_num].copy_(self.temp_fp32_data) + self.fp16_data.record_stream(torch.cuda.current_stream()) + + self.__update_params_ptr() + + def move_grad_to_cpu(self): + """ + Move gradients from GPU to CPU. + """ + + self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True) + self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True) + self.fp16_data.record_stream(torch.cuda.current_stream()) + if not self.in_mem_pool_flag: + self.free_cuda_data() + + self.grad_num = 0 + + def free_cuda_data(self): + free_storage(self.fp16_data) + + # torch.cuda.empty_cache() + + def copy_grad_to_region_slice(self, param: torch.nn.Parameter, data_slice: torch.Tensor) -> None: + """ + Copy data slice to the memory space indexed by the input tensor in the region. + + Args: + param (torch.nn.Parameter): the param used to retrieve meta information + data_slice (torch.Tensor): the tensor to be copied to the region + """ + + begin, end = self.param_to_range[param] + self.fp16_data[begin:end].copy_(data_slice.data.flatten()) + param.data = self.fp16_data[begin:end].view(param.data.shape) + + self.grad_num += data_slice.numel() + + def split(self, cut_node_idx: int, cut_param_idx: int): + """ + Split the region into two and return the latter. + """ + new_reg = Region(r_id=self.r_id + 1) + new_reg.nodes = self.nodes[cut_node_idx:] + new_reg.fp16_params = self.fp16_params[cut_param_idx:] + for p in new_reg.fp16_params: + new_reg.param_size += p.data.numel() * p.data.element_size() + new_reg.param_num += p.data.numel() + + self.nodes = self.nodes[:cut_node_idx] + self.fp16_params = self.fp16_params[:cut_param_idx] + self.param_size -= new_reg.param_size + self.param_num -= new_reg.param_num + + return new_reg + + def __update_params_ptr(self) -> None: + for param in self.fp16_params: + begin, end = self.param_to_range[param] + param.data = self.fp16_data[begin:end].view(param.data.shape) diff --git a/colossalai/auto_parallel/offload/region_manager.py b/colossalai/auto_parallel/offload/region_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..30bfaf00d4939afadc3c2aaaec3f27ce70db4a20 --- /dev/null +++ b/colossalai/auto_parallel/offload/region_manager.py @@ -0,0 +1,526 @@ +from typing import List, Any, Dict, Tuple +import torch +from torch.fx import Graph, Node + +from .solver import SolverFactory +from .training_simulator import TrainingSimulator +from .region import Region +from .util import NodeInfo + + +class RegionManager: + """ + RegionManager is used to construct and manage the offload plan for the model execution. + + Args: + graph (Graph): a Graph object used for analysis and strategy generation. + solver_name (str): a solver name which specifies the preferences for plan searching. + memory_budget (float): the given memory budget. + cnode (List[str], optional): Common node List, should be the subset of input. + """ + + def __init__(self, + graph: Graph, + solver_name: str = 'asyn', + memory_budget: float = -1.0, + cnode: List[str] = None): + + self.graph = graph + assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' + self.root_module = self.graph.owning_module + self.nodes = list(graph.nodes) + self.cnode = cnode + self.only_param_ops = [] + self.param_region_map: Dict[torch.nn.Parameter, Region] = dict() + self.shared_region_pairs: List[Tuple[Region, Region]] = list() + self.region_list: List[Region] = list() + self.rid_in_pool: List[int] = list() + self.mem_block_size: int = 0 + self.memory_budget = memory_budget + + self.solver_name = solver_name + self.require_pool: bool = solver_name == 'asyn' + + self.reg_to_block: Dict[int, int] = dict() + + def _build_regions(self): + """ + 1. Pre-processing, mainly contains linearized computing graph and + merge smaller regions into larger ones. + 2. Construct a solver to search for an efficient offload strategy. + 3. Post-processing, mainly contains early region placement if using asynchronous mode, + and initialize region data. + """ + + self._pre_process() + + solver_cls = SolverFactory.create(self.solver_name) + solver = solver_cls(self.region_list, self.memory_budget) + solver._call_solver() + + self._post_process(solver.best_ts) + + def _pre_process(self): + + init_region_list = self._linearize_graph() + + if len(self.shared_region_pairs) > 1: + raise NotImplementedError( + 'The current version only considers at most one pair of parameter sharing.') + + elif len(self.shared_region_pairs) == 1: + shared_regs = self.shared_region_pairs[0] + assert shared_regs[0].shared_rid == shared_regs[1].r_id \ + and shared_regs[1].shared_rid == shared_regs[0].r_id + fst_id = shared_regs[0].r_id + lst_id = shared_regs[1].r_id + regs_left_out = init_region_list[:fst_id + 1] + regs_right_out = init_region_list[lst_id:] + hold_regs = init_region_list[fst_id + 1:lst_id] + else: + regs_left_out = [] + regs_right_out = [] + hold_regs = init_region_list + + self.mem_block_size = self._search_block_size(hold_regs) + hold_regs = self._merge_small_regions(hold_regs) + + if self.require_pool: + for reg in hold_regs: + reg.in_mem_pool_flag = True + self.rid_in_pool.append(reg.r_id) + + self.region_list.extend(regs_left_out) + self.region_list.extend(hold_regs) + + for reg in regs_right_out: + reg.r_id = self.region_list[-1].r_id + 1 + self.region_list[reg.shared_rid].shared_rid = reg.r_id + self.region_list.append(reg) + + self._process_shared_region() + + self.max_param_num = max([reg.param_num for reg in self.region_list]) + self.memory_budget -= self.max_param_num * torch.tensor([], dtype=torch.float32).element_size() + + def _post_process(self, ts: TrainingSimulator = None): + if self.require_pool: + self._early_region_placement(ts) + self._init_region_data() + + def _early_region_placement(self, ts: TrainingSimulator): + """ + Implemented the early region placement strategy to avoid GPU memory fragmentation. + It maps all region data into a contiguous memory space and + reuses the same memory space for regions that do not coexist. + + Args: + ts (TrainingSimulator): the best training simulator, which records region execution flow. + + Raises: + NotImplementedError: due to the naive implementation, + it may not find a suitable region placement strategy for the given execution flow. + """ + + reg_flow = torch.cat( + [ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0) + mem_block_num = torch.max( + torch.sum(reg_flow[:, self.rid_in_pool], dim=1)) + coexist_matrix = torch.logical_or( + ts.fwd_reg_flow, ts.bwd_reg_flow) + + block_to_regs = {} + for block_idx in range(mem_block_num): + block_to_regs[block_idx] = [] + for reg in self.region_list: + if reg.r_id in self.rid_in_pool: + cur_reg_appears = coexist_matrix[:, reg.r_id] + cur_reg_coexists = torch.sum( + coexist_matrix[cur_reg_appears], dim=0).bool() + for block_idx in range(mem_block_num): + if not any(cur_reg_coexists[block_to_regs[block_idx]]): + block_to_regs[block_idx].append(reg.r_id) + self.reg_to_block[reg.r_id] = block_idx + break + + if reg.r_id not in self.reg_to_block: + raise NotImplementedError( + f'can not find a block from the memory pool to store parameters of the region') + self.memory_pool = torch.chunk(torch.zeros(int( + mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num)) + + def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]: + """ + Merge smaller regions into larger ones for better bandwidth utilization and easier management. + It is inspired by Gemini. + + Args: + orig_reg_list (List[Region]): original region list. + + Returns: + List[Region]: region list after merging. + """ + + r_id = orig_reg_list[0].r_id + region = Region(r_id=r_id) + region_list = [region] + + for orig_reg in orig_reg_list: + if region_list[-1].param_size + orig_reg.param_size > self.mem_block_size: + r_id += 1 + region = Region(r_id=r_id) + region_list.append(region) + region.param_size += orig_reg.param_size + region.param_num += orig_reg.param_num + region.nodes.extend(orig_reg.nodes) + region.fp16_params.extend(orig_reg.fp16_params) + self.__update_param_region_map(orig_reg.fp16_params, region) + + return region_list + + def _search_block_size(self, + region_list: List[Region], + search_interval_byte: int = 1024, + search_range_byte: int = 128 * 1024 ** 2) -> int: + """ + Search for a suitable memory block size. + + Args: + region_list (List[Region]): region list. + search_interval_byte (int): searching interval in byte. + search_range_byte (int): searching range in byte. + + Returns: + int: the best memory block size. + """ + + def _get_wasted_mem(size_list: List[int], blk_size: int): + """ + Get wasted byte for a certain block size. + """ + acc_wasted = 0 + left = 0 + for s in size_list: + if left + s > blk_size: + acc_wasted += blk_size - left + left = s + left += s + acc_wasted += blk_size - left + return acc_wasted + + param_size_list = [ + region.param_size for region in region_list if region.r_id == region.shared_rid] + + start_size = max(param_size_list) + min_mem_waste = float('+inf') + best_block_size = start_size + + for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): + temp_waste = 0 + temp_waste += _get_wasted_mem(param_size_list, block_size) + if temp_waste < min_mem_waste: + min_mem_waste = temp_waste + best_block_size = block_size + + return best_block_size + + def _init_region_data(self): + """ + Initialize region data, which maps the parameters in the region to a contiguous memory space. + """ + + self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32) + + for region in self.region_list: + pre_alloc_tensor = None + if self.require_pool and region.r_id in self.rid_in_pool: + block_idx = self.reg_to_block[region.r_id] + pre_alloc_tensor = self.memory_pool[block_idx] + + if region.r_id <= region.shared_rid: + region.init_param_data(pre_alloc_tensor) + else: + shared_region = self.region_list[region.shared_rid] + region.fp16_data = shared_region.fp16_data + region.fp32_data = shared_region.fp32_data + region.param_to_range = shared_region.param_to_range + region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach( + ) + + torch.cuda.empty_cache() + + def _process_shared_region(self): + """ + Special processing for the shared region, which uses GPT2 and Bert case as a priori knowledge. + """ + + if len(self.shared_region_pairs): + assert len(self.shared_region_pairs) <= 1 + former_reg, latter_reg = self.shared_region_pairs[0] + assert latter_reg.param_num >= former_reg.param_num + embedding_node = former_reg.nodes[-1] + assert embedding_node.op == 'call_module' and isinstance( + self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding) + if latter_reg.param_num > former_reg.param_num: + for idx, n in enumerate(latter_reg.nodes): + if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target), + torch.nn.Linear)) or \ + (n.op == 'call_function' and n.target is torch.nn.functional.linear): + cut_node_idx = idx + 1 + break + assert len(latter_reg.fp16_params) == 2 + new_reg = latter_reg.split(cut_node_idx, 1) + for p in new_reg.fp16_params: + self.param_region_map[p] = new_reg + self.region_list.insert(new_reg.r_id, new_reg) + for reg in self.region_list[new_reg.r_id + 1:]: + reg.r_id += 1 + latter_reg.shared_rid = former_reg.r_id + former_reg.shared_rid = latter_reg.r_id + + def _linearize_graph(self) -> List[Region]: + """Linearizing the graph + + Args: + graph (Graph): The computing graph to be optimized. + + Returns: + List[Region]: each region contains the actual 'node' in linearized manner. + + Remarks: + Do merge the inplace ops and shape-consistency ops into the previous node. + """ + + # List of target name that could be seen as common node + common_ops = ["getattr", "getitem", "size"] + + def _is_cop(target: Any) -> bool: + """Check if an op could be seen as common node + + Args: + target (Any): node target + + Returns: + bool + """ + + if isinstance(target, str): + return target in common_ops + else: + return target.__name__ in common_ops + + def _is_act(data: Any) -> bool: + """Check if an op could be seen as parameter computation start + + Args: + data (Any): meta_data + + Returns: + bool + """ + + label = False + if isinstance(data, torch.Tensor): + return True + elif isinstance(data, (tuple, list)): + for d in data: + label = label or _is_act(d) + return label + + def _maybe_param_comp_start() -> bool: + """Check if an op could be seen as parameter computation start + + Args: + n (Node): node + + Returns: + bool + """ + + label = False + if n.op == "get_attr": + label = True + elif n.op == "call_module": + target = n.target + submod = self.root_module.get_submodule(target) + if ( + len(list(submod.named_parameters(recurse=False))) != 0 + or len(list(submod.named_buffers(recurse=False))) != 0 + ): + label = True + + return label and not sum([v for _, v in param_op_deps.items()]) + + def _is_param_comp_end() -> bool: + """Check if an op could be seen as parameter computation end + + Args: + n (Node): node + + Returns: + bool + """ + + def _is_inplace(n: Node): + """Get the inplace argument from ``torch.fx.Node`` + """ + inplace = False + if n.op == "call_function": + inplace = n.kwargs.get("inplace", False) + elif n.op == "call_module": + inplace = getattr(n.graph.owning_module.get_submodule( + n.target), "inplace", False) + return inplace + + label = False + + if n.op == "call_module": + target = n.target + submod = self.root_module.get_submodule(target) + if ( + len(list(submod.named_parameters(recurse=False))) != 0 + or len(list(submod.named_buffers(recurse=False))) != 0 + ): + label = True + + elif n.op == "call_function": + label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any( + map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes)) + + return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users)) + + def _exception_node_handling(): + # TODO meta info prop bug + if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2: + n.meta['fwd_out'] = [] + + # make sure that item in cnode is valid + if self.cnode: + for name in self.cnode: + try: + assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \ + f"Common node {name} is not an input of the model." + except StopIteration: + raise ValueError(f"Common node name {name} not in graph.") + else: + self.cnode = [] + + node_id = 0 + region_id = 0 + + param_op_deps = {} + + deps = {} + region_list = [] + region = Region(r_id=region_id) + + act_n = None + + for n in self.graph.nodes: + if n.op != "placeholder" and n.op != "output": + for n_par in n.all_input_nodes: + if n_par.op != "placeholder" and n_par.name not in self.cnode: + deps[n_par] -= 1 + if n_par.op != "placeholder" and n_par.name in self.only_param_ops: + param_op_deps[n_par] -= 1 + + if act_n in region.nodes and _maybe_param_comp_start(): + ns = [] + border_n_idx = region.nodes.index(act_n) + if border_n_idx < len(region.nodes): + ns = region.nodes[border_n_idx + 1:] + region.nodes = region.nodes[:border_n_idx + 1] + region_list.append(region) + region_id += 1 + region = Region(r_id=region_id) + region.nodes = ns + + _exception_node_handling() + region.nodes.append(n) + self._set_node_and_region_info(node_id, n, region) + node_id += 1 + + # if the node could free all dependencies in graph + # we could begin a new region + if _is_param_comp_end(): + region_list.append(region) + region_id += 1 + region = Region(r_id=region_id) + + # propagate common node attr if possible + if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode + ]) or _is_cop(n.target): + self.cnode.append(n.name) + else: + deps[n] = len( + [user for user in n.users if user.op != "output"]) + + # propagate param node attr if possible + if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops + ]) or n.op == "get_attr": + self.only_param_ops.append(n.name) + param_op_deps[n] = len( + [user for user in n.users if user.op != "output"]) + + # record last activation node + if _is_act(n._meta_data): + act_n = n + + if len(region.nodes): + region_list.append(region) + + return region_list + + def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region): + + cur_n.node_info = NodeInfo(node_id) + + if cur_n.op == 'call_module': + target = cur_n.target + submod = self.root_module.get_submodule(target) + for p in list(submod.parameters(recurse=False)): + + if p in self.param_region_map: + cur_reg.shared_rid = self.param_region_map[p].r_id + self.param_region_map[p].shared_rid = cur_reg.r_id + self.shared_region_pairs.append( + (self.param_region_map[p], cur_reg)) + else: + self.param_region_map[p] = cur_reg + + cur_reg.fp16_params.append(p) + cur_reg.param_num += p.data.numel() + cur_reg.param_size += p.data.numel() * p.data.element_size() + + elif cur_n.op == "get_attr": + attr_itr = self.root_module + atoms = cur_n.target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + + if isinstance(attr_itr, torch.nn.Parameter): + + if attr_itr in self.param_region_map: + cur_reg.shared_rid = self.param_region_map[attr_itr].r_id + self.param_region_map[attr_itr].shared_rid = cur_reg.r_id + self.shared_region_pairs.append( + (self.param_region_map[attr_itr], cur_reg)) + else: + self.param_region_map[attr_itr] = cur_reg + + cur_reg.fp16_params.append(attr_itr) + cur_reg.param_num += attr_itr.data.numel() + cur_reg.param_size += attr_itr.data.numel() * attr_itr.data.element_size() + + def get_region(self, param: torch.nn.Parameter) -> Region: + """ + Return the region owning the parameter. + + Args: + param (torch.nn.Parameter): a torch parameter object + """ + return self.param_region_map[param] + + def __update_param_region_map(self, params: List[torch.nn.Parameter], region: Region): + for p in params: + self.param_region_map[p] = region diff --git a/colossalai/auto_parallel/offload/runtime.py b/colossalai/auto_parallel/offload/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..764ac608826b2860f3294ec9a5883745dfa10f57 --- /dev/null +++ b/colossalai/auto_parallel/offload/runtime.py @@ -0,0 +1,256 @@ +from typing import List + +import torch +from torch.fx.node import Node + +from .region import Region +from .util import GlobalRuntimeInfo, requires_upload_p_in_fwd + + +class SynPreFwdPostBwdOP(torch.autograd.Function): + """ + A customized prefetch and offload operation. + + Args: + input_: input tensor. + fwd_info: information dict, which contains region indices + that need to be uploaded or freed during forward pass. + bwd_info: information dict, which contains region indices + that need to be uploaded during backward pass. + """ + + @staticmethod + def forward(ctx, input_, fwd_info, bwd_info): + ctx.bwd_info = bwd_info + d2h_rid = fwd_info.get('d2h_rid', None) + if d2h_rid is not None: + free_region = GlobalRuntimeInfo().region_list[d2h_rid] + assert isinstance(free_region, Region) + free_region.free_cuda_data() + + h2d_rid = fwd_info.get('h2d_rid', None) + if h2d_rid is not None: + h2d_region = GlobalRuntimeInfo().region_list[h2d_rid] + assert isinstance(h2d_region, Region) + h2d_region.move_param_to_cuda() + + return input_ + + @staticmethod + def backward(ctx, grad_output): + + h2d_rid = ctx.bwd_info.get('h2d_rid', None) + if h2d_rid is not None: + pref_region = GlobalRuntimeInfo().region_list[h2d_rid] + assert isinstance(pref_region, Region) + pref_region.move_param_to_cuda() + + return grad_output, None, None + + +class AsynPreFwdPostBwdOP(torch.autograd.Function): + """ + A customized prefetch and offload operation. + + Args: + input_: input tensor. + fwd_info: information dict, which contains region indices + that need to be prefetched, waited, or freed during forward pass. + bwd_info: information dict, which contains region indices + that need to be prefetched or waited during backward pass. + """ + + @staticmethod + def forward(ctx, input_, fwd_info, bwd_info): + ctx.bwd_info = bwd_info + + sync_rid = fwd_info.get('sync_rid', None) + if sync_rid is not None: + prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None) + if prefetch_event: + prefetch_event.wait() + + h2d_rid = fwd_info.get('h2d_rid', None) + if h2d_rid is not None: + pref_region = GlobalRuntimeInfo().region_list[h2d_rid] + assert isinstance(pref_region, Region) + master_stream = torch.cuda.current_stream() + with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream): + GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream) + pref_region.move_param_to_cuda() + + prefetch_event = torch.cuda.Event() + prefetch_event.record(GlobalRuntimeInfo().h2d_stream) + GlobalRuntimeInfo().fwd_prefetch_event_map[h2d_rid] = prefetch_event + + return input_ + + @staticmethod + def backward(ctx, grad_output): + + sync_rid = ctx.bwd_info.get('sync_rid', None) + if sync_rid is not None: + wait_region = GlobalRuntimeInfo().region_list[sync_rid] + assert isinstance(wait_region, Region) + prefetch_event = GlobalRuntimeInfo().bwd_prefetch_event_map.get(sync_rid, None) + if prefetch_event: + prefetch_event.wait() + else: + wait_region.move_param_to_cuda() + + h2d_rid = ctx.bwd_info.get('h2d_rid', None) + if h2d_rid is not None: + pref_region = GlobalRuntimeInfo().region_list[h2d_rid] + assert isinstance(pref_region, Region) + master_stream = torch.cuda.current_stream() + with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream): + GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream) + pref_region.move_param_to_cuda() + + prefetch_event = torch.cuda.Event() + prefetch_event.record(GlobalRuntimeInfo().h2d_stream) + GlobalRuntimeInfo().bwd_prefetch_event_map[h2d_rid] = prefetch_event + return grad_output, None, None + + +def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info): + ''' + Convert Upload and Offload operation into runtime action. + + Argument: + tensor(torch.Tensor): input tensor. + fwd_info(dict): information dict, which contains region indices + that need to be uploaded, or freed during forward pass. + bwd_info(dict): information dict, which contains region indices + that need to be uploaded during backward pass. + ''' + with torch._C.DisableTorchFunction(): + ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) + return ret + + +def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info): + ''' + Convert Prefetch and Offload operation into runtime action. + + Argument: + tensor(torch.Tensor): input tensor. + fwd_info(dict): information dict, which contains region indices + that need to be prefetched, waited, or freed during forward pass. + bwd_info(dict): information dict, which contains region indices + that need to be prefetched or waited during backward pass. + ''' + with torch._C.DisableTorchFunction(): + ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) + return ret + + +def replace_node_users(orig_node: Node, inserted_node: Node, rep_user_nodes: List[Node] = None): + user_list = list(orig_node.users.keys()) + if rep_user_nodes is not None: + user_list = rep_user_nodes + for user in user_list: + if user == inserted_node: + continue + new_args = list(user.args) + new_kwargs = dict(user.kwargs) + # the origin node may be a positional argument or key word argument of user node + if orig_node in new_args: + # substitute the origin node with offload_apply_node + new_args[new_args.index(orig_node)] = inserted_node + user.args = tuple(new_args) + elif str(orig_node) in new_kwargs: + # substitute the origin node with offload_apply_node + new_kwargs[str(orig_node)] = inserted_node + user.kwargs = new_kwargs + + +def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]): + """ + This pass is used to add the synchronous upload and offload spec apply node to the origin graph. + """ + mod_graph = gm.graph + last_inp_node = tuple(mod_graph.nodes)[0] + + for r_idx, region in enumerate(region_list): + # forward upload + fwd_info = {} + if requires_upload_p_in_fwd(region_list[region.shared_rid]): + fwd_info['h2d_rid'] = region.r_id + + # forward offload + if r_idx > 0 and region_list[r_idx - 1].need_offload: + fwd_info['d2h_rid'] = r_idx - 1 + + bwd_info = {} + # backward upload + if r_idx > 0 and region_list[r_idx - 1].need_offload: + bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id + + if fwd_info or bwd_info: + with mod_graph.inserting_after(last_inp_node): + new_node = mod_graph.create_node('call_function', + convert_fwd_upload_bwd_offload_to_action, + args=(last_inp_node, fwd_info, bwd_info)) + replace_node_users(last_inp_node, new_node) + + last_inp_node = region.nodes[-1] + + return gm + + +def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]): + """ + This pass is used to add the asynchronous prefetch and offload spec apply node to the origin graph. + """ + mod_graph = gm.graph + + # upload parameters of the first region + last_inp_node = tuple(mod_graph.nodes)[0] + first_region_with_p = [region for region in region_list if region.param_size][0] + fwd_info = {"h2d_rid": first_region_with_p.r_id} + with mod_graph.inserting_after(last_inp_node): + upload_apply_node = mod_graph.create_node('call_function', + convert_fwd_upload_bwd_offload_to_action, + args=(last_inp_node, fwd_info, {})) + replace_node_users(last_inp_node, upload_apply_node) + last_inp_node = upload_apply_node + + for r_idx, region in enumerate(region_list): + # forward prefetch + fwd_info = {} + if region.param_size: + fwd_info['sync_rid'] = region.r_id + fwd_prefetch_region = region.fwd_prefetch_region + if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]): + fwd_info['h2d_rid'] = fwd_prefetch_region.r_id + + # forward offload + if r_idx > 0 and region_list[r_idx - 1].need_offload: + fwd_info['d2h_rid'] = r_idx - 1 + + bwd_info = {} + # backward prefetch + if r_idx > 0 and region_list[r_idx - 1].need_offload: + bwd_info['sync_rid'] = r_idx - 1 + if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region: + bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id + + if fwd_info or bwd_info: + with mod_graph.inserting_after(last_inp_node): + new_node = mod_graph.create_node('call_function', + convert_fwd_prefetch_bwd_offload_to_action, + args=(last_inp_node, fwd_info, bwd_info)) + replace_node_users(last_inp_node, new_node) + + last_inp_node = region.nodes[-1] + + if region.bwd_prefetch_region: + bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id} + with mod_graph.inserting_after(last_inp_node): + new_node = mod_graph.create_node('call_function', + convert_fwd_prefetch_bwd_offload_to_action, + args=(last_inp_node, {}, bwd_info)) + replace_node_users(last_inp_node, new_node) + # gm.graph.print_tabular() + return gm diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py new file mode 100644 index 0000000000000000000000000000000000000000..161f7ff868981d913047d7073bb144694ab68591 --- /dev/null +++ b/colossalai/auto_parallel/offload/solver.py @@ -0,0 +1,523 @@ +import time +from typing import List, Dict, Type +from abc import ABC, abstractmethod + +NOT_NVML = False +try: + from pynvml import * +except: + NOT_NVML = True + +import torch +from torch.fx.node import Node +from colossalai.utils.cuda import get_current_device + +from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator +from .region import Region +from .util import NodeInfo, NvDevicePower + + +def benchmark_func(func, number=1, repeat=1, warmup=3): + """ + benchmark data transfer cost. + """ + + for i in range(warmup): + func() + + costs = [] + + for i in range(repeat): + torch.cuda.synchronize() + begin = time.time() + for i in range(number): + func() + torch.cuda.synchronize() + costs.append((time.time() - begin) / number) + + return sum(costs) / len(costs) + + +class Solver(ABC): + """ + The parameter offload solver. + + Args: + region_list (List[Region]): represents the linearized DNN computing graph. + memory_budget (float): the given memory budget. + error_factor (float): the error factor. + It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time. + """ + + def __init__(self, + region_list: List[Region], + memory_budget: float = -1.0, + error_factor: float = 0.95) -> None: + + self.region_list = region_list + + self.error_factor: float = error_factor + if memory_budget > 0: + self.memory_budget = memory_budget * self.error_factor + else: + self.memory_budget = torch.cuda.get_device_properties( + get_current_device()).total_memory * self.error_factor + + self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth() + self.comp_power: float = self._extract_computing_power() + + @abstractmethod + def _call_solver(self): + raise NotImplementedError + + @abstractmethod + def _try_to_offload(self, *args): + raise NotImplementedError + + @abstractmethod + def _eval_one_choice(self, *args): + raise NotImplementedError + + def _compute_offload_profit(self, total_mem_saving: float, peak_mem_saving: float, extra_cost: float): + """ + Compute the profits of the offload strategies, + which packages the memory savings information for subsequent comparisons. + + Args: + total_mem_saving (float): the total memory saving of the offload strategy. + peak_mem_saving (float): the peak memory saving of the offload strategy. + extra_cost (float): extra data transfer cost. + + Returns: + tuple: profit information, the first term represents memory savings per unit of time. + """ + + if extra_cost == 0: + # means data transfer overhead can be completely overlapped + return (float('inf'), total_mem_saving, peak_mem_saving) + return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving) + + def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool: + """ + Compare the profits of the two offload strategies using the dictionary order algorithm. + + Args: + profit_a (tuple): the profit of a offload strategy. + profit_b (tuple): the profit of another offload strategy. + + Returns: + bool: whether profit_a is greater than profit_b. + """ + + for val1, val2 in zip(profit_a, profit_b): + if val1 != val2: + return val1 > val2 + return False + + def _update_state(self, best_ts: TrainingSimulator): + """ + Update the solver state. + """ + + self.best_ts = best_ts + self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem) + + def _update_node_mem_info(self, + fwd_mem_info: Dict[Node, float], + bwd_mem_info: Dict[Node, float]): + """ + Update the runtime memory information of the node. + + Args: + fwd_mem_info (Dict[Node, float]): the runtime memory of each node in forward pass. + bwd_mem_info (Dict[Node, float]): the runtime memory of each node in backward pass. + """ + + for node, mem in fwd_mem_info.items(): + assert hasattr(node, 'node_info') and isinstance( + node.node_info, NodeInfo) + node.node_info.runtime_fwd_mem = mem + for node, mem in bwd_mem_info.items(): + assert hasattr(node, 'node_info') and isinstance( + node.node_info, NodeInfo) + node.node_info.runtime_bwd_mem = mem + + def _extract_computing_power(self): + """ + return the FP16 computing performance of the current NVIDIA GPU. + + Raises: + TypeError: Unknown NVIDIA GPU device. + """ + + nvmlInit() + handle = nvmlDeviceGetHandleByIndex(0) + device_name = nvmlDeviceGetName(handle) + units = 1e12 + + if device_name.__contains__("RTX 3080"): + return NvDevicePower.RTX3080_FP16 * units + elif device_name.__contains__("RTX 3090"): + return NvDevicePower.RTX3090_FP16 * units + elif device_name.__contains__('V100'): + return NvDevicePower.V100_FP16 * units + elif device_name.__contains__("A100"): + return NvDevicePower.A100_FP16 * units + else: + raise TypeError(f'Unknown NVIDIA GPU device name {device_name}') + + def _profile_bandwidth(self): + """ + Profile the bidirectional communication bandwidth between CPU and GPU + using data volumes ranging from 1KB to 1GB. + """ + + print('profiling bandwidth ......') + link_to_bandwidth = {} + links = ['h2d', 'd2h'] + + for link in links: + t_size = 1024 + size_to_bandwidth = {} + + # from 1KB to 1GB + for i in range(21): + if link == 'h2d': + src_tensor = torch.ones( + int(t_size), dtype=torch.int8, pin_memory=True) + dst_tensor = torch.ones( + (int(t_size)), dtype=torch.int8, device='cuda') + elif link == 'd2h': + src_tensor = torch.ones( + int(t_size), dtype=torch.int8, device='cuda') + dst_tensor = torch.ones( + (int(t_size)), dtype=torch.int8, pin_memory=True) + + def func(): + dst_tensor.copy_(src_tensor) + + size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3) + print(f'size: {t_size / 1024 ** 2:.3f} MB, ' + f'{src_tensor.device.type}-to-{dst_tensor.device.type} ' + f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s') + + t_size *= 2 + + link_to_bandwidth[link] = size_to_bandwidth + return link_to_bandwidth + + +class SynGreedySolver(Solver): + + def __init__(self, + region_list: List[Region], + memory_budget: float = -1.0) -> None: + super().__init__(region_list, memory_budget) + + self.best_ts: SynTrainingSimulator = None + self._init_state() + + def _init_state(self): + """ + Initialize the solver state when without offloading. + """ + + ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) + ts.execute() + self._update_state(ts) + + def _call_solver(self): + """ + Call the solver to search an efficient parameter offloading strategy for the linearized graph. + The solver adopts greedy algorithm. + + Raises: + NotImplementedError: Unable to find a solution for the given memory budget. + """ + + print("search offloading strategy ......") + while self.best_ts.peak_mem > self.memory_budget: + offload_region = None + best_ts = None + max_profit = (0,) + + # search which region should be offloaded, + # the last region does not need to be offloaded. + for region in self.region_list[:-1]: + if region.param_size and not region.need_offload: + temp_ts, profit = self._try_to_offload(region) + if self._compare_profit(profit, max_profit): + offload_region = region + max_profit = profit + best_ts = temp_ts + + if offload_region is not None and best_ts is not None: + offload_region.need_offload = True + offload_region.is_syn = True + self._update_state(best_ts) + else: + raise NotImplementedError( + f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " + f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") + + def _call_solver_l2l(self): + """ + The layer-wise offload strategy. + """ + + for region in self.region_list[:-1]: + region.need_offload = True + region.is_syn = True + + def _try_to_offload(self, offload_region: Region): + + # record previous information + orig_need_offload = offload_region.need_offload + assert not orig_need_offload + offload_region.need_offload = True + + ts, profit = self._eval_one_choice(offload_region) + + # restore previous information + offload_region.need_offload = orig_need_offload + return ts, profit + + def _eval_one_choice(self, offload_region: Region): + """ + Evaluate the profit of a strategy choice. + + Args: + offload_region (Region): the offload region of current choice. + + Returns: + SynTrainingSimulator: the training simulator corresponding to the current strategy. + tuple: contains memory saving and cost information of the current strategy. + """ + + ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) + ts.execute() + + extra_comm_cost = 2.0 * \ + ts._get_communication_overhead('h2d', offload_region.param_size) + # the shared region needs to be moved twice + if offload_region.r_id < offload_region.shared_rid: + extra_comm_cost *= 2.0 + profit = self._compute_offload_profit( + ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) + + return ts, profit + + +class AsynGreedySolver(Solver): + + def __init__(self, + region_list: List[Region], + memory_budget: float = -1.0, + search_window_size: int = 3): + super().__init__(region_list, memory_budget) + + self.search_window_size = search_window_size + # Records the prefetch execution location of the offloaded region + self.region_to_region_map = {} + self.best_ts: AsynTrainingSimulator = None + + self._init_state() + + def _init_state(self): + """ + Initialize the solver state when without offloading. + """ + + ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) + ts.execute() + self._update_state(ts) + print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB") + + def _call_solver(self): + """ + Call the solver to search an efficient parameter offloading strategy for the linearized graph. + The solver adopts greedy algorithm. + + Raises: + NotImplementedError: Unable to find a solution for the given memory budget. + """ + + print("search for offloading strategy ......") + # Records the prefetch execution location of the offloaded region + region_to_region_map = {} + while self.best_ts.peak_mem > self.memory_budget: + region_to_offload = None + max_offload_profit = (0,) + best_offl_ts = None + + # search which region should be offloaded, + # the last region does not need to be offloaded + for region in self.region_list[:-1]: + if region.param_size and not region.need_offload: + max_prefetch_profit = (0,) + best_pref_ts = None + + # search when to prefetch the region offloaded + for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]: + if host_region.bwd_prefetch_region is not None: + continue + + temp_ts, profit = self._try_to_offload( + host_region, region) + + if self._compare_profit(profit, max_prefetch_profit): + region_to_region_map[region.r_id] = host_region + max_prefetch_profit = profit + best_pref_ts = temp_ts + if profit[0] == float('inf'): + break + + if self._compare_profit(max_prefetch_profit, max_offload_profit): + region_to_offload = region + max_offload_profit = max_prefetch_profit + best_offl_ts = best_pref_ts + + if (region_to_offload is not None) and (best_offl_ts is not None): + region_to_offload.need_offload = True + if region_to_region_map[region_to_offload.r_id] == region_to_offload: + region_to_offload.is_syn = True + else: + region_to_region_map[region_to_offload.r_id].bwd_prefetch_region = region_to_offload + self.region_to_region_map[region_to_offload.r_id] = region_to_region_map[region_to_offload.r_id] + + self._update_state(best_offl_ts) + + elif self.region_to_region_map.__len__() > 0: + self._repair_strategy() + else: + raise NotImplementedError( + f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " + f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") + + region_to_region_map.clear() + + def _try_to_offload(self, host_region: Region, offload_region: Region): + """ + Attempts to offload the region and prefetch it in backward pass. + """ + + # record previous information + orig_prefetch = host_region.bwd_prefetch_region + orig_is_syn = offload_region.is_syn + orig_need_offload = offload_region.need_offload + + if host_region == offload_region: + offload_region.is_syn = True + else: + host_region.bwd_prefetch_region = offload_region + offload_region.need_offload = True + + ts, profit = self._eval_one_choice() + + # restore previous information + host_region.bwd_prefetch_region = orig_prefetch + offload_region.is_syn = orig_is_syn + offload_region.need_offload = orig_need_offload + + return ts, profit + + def _try_convert_to_syn_upload(self, host_region: Region, offload_region: Region): + """ + Attempts to convert asynchronous prefetch into synchronous upload operations. + """ + + # record previous information + orig_prefetch = host_region.bwd_prefetch_region + orig_is_syn = offload_region.is_syn + assert orig_prefetch is not None and not orig_is_syn + + host_region.bwd_prefetch_region = None + offload_region.is_syn = True + + ts, profit = self._eval_one_choice() + + # restore previous information + host_region.bwd_prefetch_region = orig_prefetch + offload_region.is_syn = orig_is_syn + + return ts, profit + + def _repair_strategy(self): + """ + Repair offload strategy. + It attempts to convert asynchronous prefetch into synchronous upload operations and selects the best one. + The repair process does not end until peak memory is reduced or there is no asynchronous prefetch operation. + """ + print("repair strategy ......") + + peak_mem_saving = 0 + while len(self.region_to_region_map) and peak_mem_saving <= 0: + + max_profit = (0,) + best_ts = None + undo_host_region = None + undo_offload_region = None + + for offload_region_id, host_region in self.region_to_region_map.items(): + offload_region = self.region_list[offload_region_id] + assert host_region.bwd_prefetch_region == offload_region + assert offload_region.need_offload + assert not offload_region.is_syn + + ts, profit = self._try_convert_to_syn_upload(host_region, + offload_region) + + if self._compare_profit(profit, max_profit): + undo_host_region = host_region + undo_offload_region = offload_region + max_profit = profit + best_ts = ts + + if best_ts is None: + raise NotImplementedError('repair error!') + + assert not undo_offload_region.is_syn + undo_offload_region.is_syn = True + undo_host_region.bwd_prefetch_region = None + + peak_mem_saving = self.best_ts.peak_mem - best_ts.peak_mem + + self._update_state(best_ts) + self.region_to_region_map.pop(undo_offload_region.r_id) + + return best_ts + + def _eval_one_choice(self): + """ + Evaluate the profit of a strategy choice. + + Returns: + AsynTrainingSimulator: the training simulator corresponding to the current strategy. + tuple: contains memory saving and cost information of the current strategy. + """ + + ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) + ts.execute() + + extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0) + profit = self._compute_offload_profit( + ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) + + return ts, profit + + +class SolverFactory: + solvers: Dict[str, Type[Solver]] = { + 'syn': SynGreedySolver, + 'asyn': AsynGreedySolver + } + + @staticmethod + def create(solver_name: str) -> Type[Solver]: + if solver_name not in SolverFactory.solvers: + raise TypeError(f"Unknown parameter offload policy {solver_name}") + return SolverFactory.solvers[solver_name] + + @staticmethod + def get_solver_names(): + return tuple(SolverFactory.solvers.keys()) diff --git a/colossalai/auto_parallel/offload/training_simulator.py b/colossalai/auto_parallel/offload/training_simulator.py new file mode 100644 index 0000000000000000000000000000000000000000..de58023ec2d6a4b4247ad838f4e9c2e9a56da692 --- /dev/null +++ b/colossalai/auto_parallel/offload/training_simulator.py @@ -0,0 +1,458 @@ +import bisect +from typing import List, Dict +from collections import OrderedDict +from abc import ABC, abstractmethod + +from torch.fx.node import Node + +from .region import Region +from .util import * + + +@dataclass +class ExecutionPeriod: + start_time: float = 0 + end_time: float = 0 + + +class TrainingSimulator(ABC): + """ + The Training Simulator is used to simulate the training process. + It records computation, communication, and runtime memory during forward and backward passes. + + Args: + region_list (List[Region]): represents the linearized DNN computing graph. + comp_power (float): the NVIDIA GPU FP16 computing power. + link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth. + """ + + def __init__(self, + region_list: List[Region], + comp_power: float, + link_to_bw: Dict[str, Dict[float, float]]) -> None: + self.region_list = region_list + self.region_num = len(region_list) + + self.runtime_mem: int = 0 + self.peak_mem: int = 0 + self.total_mem_saving: int = 0 + + self.fwd_node_mem: Dict[Node, float] = {} + self.bwd_node_mem: Dict[Node, float] = {} + + # Node dependencies in backward pass + self.bwd_node_deps: Dict[Node, int] = {} + + self.comp_power: float = comp_power + self.link_to_bandwidth: Dict[str, Dict[float, float]] = link_to_bw + + @abstractmethod + def execute(self): + raise NotImplementedError + + @abstractmethod + def _eval_fwd_mem_per_region(self, region: Region): + raise NotImplementedError + + @abstractmethod + def _eval_bwd_mem_per_region(self, region: Region): + raise NotImplementedError + + def _get_bandwidth(self, link: str, comm_volumn: float) -> float: + """ + Get the data transfer bandwidth. + + Args: + link (str): the data transfer link. + comm_volumn (float): the amount of data transferred. + + Returns: + float: the data transfer bandwidth. + """ + + assert len(self.link_to_bandwidth) + if link not in self.link_to_bandwidth: + raise TypeError(f"Unknown data transfer link {link}") + + # size_list = sorted(list(map(float, self.link_to_bandwidth[link].keys()))) + size_list = sorted(self.link_to_bandwidth[link].keys()) + d_idx = bisect.bisect_left(size_list, comm_volumn) + return self.link_to_bandwidth[link][size_list[d_idx]] + + def _get_communication_overhead(self, link: str, comm_volumn: float) -> float: + return comm_volumn / self._get_bandwidth(link, comm_volumn) + + def _get_computing_overhead(self, flop: float) -> float: + return flop / self.comp_power + + +class SynTrainingSimulator(TrainingSimulator): + + def __init__(self, + region_list: List[Region], + comp_power: float, + link_to_bw: Dict[str, Dict[float, float]]) -> None: + super().__init__(region_list, comp_power, link_to_bw) + + def execute(self): + """ + Simulate synchronous training process. + """ + + for reg in self.region_list: + self._eval_fwd_mem_per_region(reg) + + for reg in self.region_list.__reversed__(): + self._eval_bwd_mem_per_region(reg) + + def _eval_fwd_mem_per_region(self, region: Region): + """ + Evaluate the runtime and peak memory when the forward execution reaches the current region. + """ + + # upload parameters of the current region + if requires_upload_p_in_fwd(self.region_list[region.shared_rid]): + self.runtime_mem += region.param_size + + for node in region.nodes: + self.runtime_mem += calculate_fwd_tmp(node) + \ + calculate_fwd_out(node) + self.fwd_node_mem[node] = self.runtime_mem + self.peak_mem = max(self.runtime_mem, self.peak_mem) + self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem + + if region.need_offload: + self.runtime_mem -= region.param_size + + def _eval_bwd_mem_per_region(self, region: Region): + """ + Evaluate the runtime and peak memory when the backward execution reaches the current region. + """ + + # upload parameters of the current region + if region.need_offload: + self.runtime_mem += region.param_size + + # add the gradient of the parameter + if region.r_id < region.shared_rid: + # gradient accumulation is required for shared parameters + self.runtime_mem += 2.0 * region.param_size + else: + self.runtime_mem += region.param_size + + for node in region.nodes.__reversed__(): + + self.runtime_mem -= calculate_fwd_out(node) + self.runtime_mem += node.meta['bwd_mem_tmp'] + \ + node.meta['bwd_mem_out'] + self.peak_mem = max(self.runtime_mem, self.peak_mem) + + # The memory savings of a node may be negative due to parameter prefetch. + self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem + self.bwd_node_mem[node] = self.runtime_mem + + self.runtime_mem -= (node.meta['bwd_mem_tmp'] + + calculate_fwd_tmp(node)) + + # free bwd_mem_out + self.bwd_node_deps[node] = len(node.all_input_nodes) + for user_node in node.users: + if user_node in self.bwd_node_deps: + self.bwd_node_deps[user_node] -= 1 + if self.bwd_node_deps[user_node] <= 0: + self.runtime_mem -= user_node.meta['bwd_mem_out'] + + if self.runtime_mem < 0: + raise ValueError(f"region id: {region.r_id}, node name: {node.name}, " + f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" + f"runtime memory computed less than 0, which is miscalculated!") + + # release parameter and offload gradient in region + if region.r_id == region.shared_rid: + self.runtime_mem -= 2.0 * region.param_size + elif region.r_id < region.shared_rid: + self.runtime_mem -= 3.0 * region.param_size + elif self.region_list[region.shared_rid].need_offload: + self.runtime_mem -= region.param_size + + +class AsynTrainingSimulator(TrainingSimulator): + + def __init__(self, + region_list: List[Region], + comp_power: float, + link_to_bw: Dict[str, Dict[float, float]]) -> None: + super().__init__(region_list, comp_power, link_to_bw) + + self.iter_end_time: int = 0 + # the last computation execution period + self.last_comp: ExecutionPeriod = ExecutionPeriod( + start_time=0, end_time=0) + # the last parameter prefetch execution period + self.last_h2d: ExecutionPeriod = ExecutionPeriod( + start_time=0, end_time=0) + # the last gradient offload execution period + self.last_d2h: ExecutionPeriod = ExecutionPeriod( + start_time=0, end_time=0) + # the forward computation execution period of the region + self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict() + # the forward parameter prefetch execution period of the region + self.fwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict() + # the backward computation execution period of the region + self.bwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict() + # the backward parameter prefetch execution period of the region + self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict() + # the gradient offload execution period of the region + # which is divided into those that are waiting and those that have been released + self.bwd_reg_to_offl_waiting: OrderedDict[int, + ExecutionPeriod] = OrderedDict() + self.bwd_reg_to_offl_freed: OrderedDict[int, + ExecutionPeriod] = OrderedDict() + # the region buffer, which records regions that are offloaded but not released + self.reg_buffer_to_free: List[int] = [] + + # node dependencies in backward pass + self.bwd_node_deps: Dict[Node, int] = {} + + # the region execution flow, + # where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU + # when the execution reaches the i-th region. + self.fwd_reg_flow = torch.zeros( + (self.region_num, self.region_num)).bool() + self.bwd_reg_flow = torch.zeros( + (self.region_num, self.region_num)).bool() + + def execute(self): + """ + Simulate asynchronous training process. + In forward pass, parameter prefetching is advanced by one region. + In backward pass, parameter prefetching is executed at the specified location, + and gradient offloading is urgent. + """ + + for reg in self.region_list: + if reg.param_size and reg.r_id < self.region_num - 1: + for nr in self.region_list[reg.r_id + 1:]: + if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]): + reg.fwd_prefetch_region = nr + break + self._eval_fwd_cost_per_region(reg) + self._eval_fwd_mem_per_region(reg) + + for reg in self.region_list.__reversed__(): + self._eval_bwd_cost_per_region(reg) + self._eval_bwd_mem_per_region(reg) + + # release remaining grads + for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items(): + self.bwd_reg_to_offl_freed[reg_id] = offl_exec + self.runtime_mem -= self.region_list[reg_id].param_size + self.bwd_reg_to_offl_waiting.clear() + + self.iter_end_time = max( + self.last_comp.end_time, self.last_d2h.end_time) + + def _insert_h2d_exec(self, region: Region, is_fwd: bool = True): + """ + Insert parameter prefetch execution period of the current region to the end of the h2d stream + """ + + pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time) + pref_end_time = pref_start_time + \ + 2.0 * self._get_communication_overhead('h2d', region.param_size) + pref_ep = ExecutionPeriod( + start_time=pref_start_time, end_time=pref_end_time) + if is_fwd: + self.fwd_reg_to_pref[region.r_id] = pref_ep + else: + self.bwd_reg_to_pref[region.r_id] = pref_ep + self.last_h2d = pref_ep + + def _insert_comp_exec(self, region: Region, is_fwd: bool = True): + """ + Insert computation execution period of the current region to the end of the computing stream + """ + + if is_fwd: + reg_to_comp = self.fwd_reg_to_comp + reg_to_pref = self.fwd_reg_to_pref + flop_key = 'fwd_flop' + else: + reg_to_comp = self.bwd_reg_to_comp + reg_to_pref = self.bwd_reg_to_pref + flop_key = 'bwd_flop' + comp_start_time = max(self.last_comp.end_time, reg_to_pref.get( + region.r_id, ExecutionPeriod(0, 0)).end_time) + comp_end_time = comp_start_time + \ + sum([self._get_computing_overhead(node.meta.get(flop_key, 0)) + for node in region.nodes]) + comp_ep = ExecutionPeriod( + start_time=comp_start_time, end_time=comp_end_time) + reg_to_comp[region.r_id] = comp_ep + self.last_comp = comp_ep + + def _insert_d2h_exec(self, region: Region): + """ + Insert gradient offload execution period of the current region to the end of the d2h stream + """ + + offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time) + offl_end_time = offl_start_time + \ + self._get_communication_overhead('d2h', region.param_size) + offl_ep = ExecutionPeriod( + start_time=offl_start_time, end_time=offl_end_time) + self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep + self.last_d2h = offl_ep + + def _eval_fwd_cost_per_region(self, region: Region): + """ + Evaluate computation and communication execution period of the region in forward pass. + """ + + # upload parameters of the first region + if region.r_id == 0: + self._insert_h2d_exec(region) + + # prefetch parameters of the next region + fwd_prefetch_region = region.fwd_prefetch_region + if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]): + self._insert_h2d_exec(fwd_prefetch_region) + + # execute computation + self._insert_comp_exec(region) + + def _eval_fwd_mem_per_region(self, region: Region): + """ + Evaluate the runtime and peak memory when the forward execution reaches the current region. + """ + + # upload parameters of the current region + if region.r_id <= 0: + self.runtime_mem += region.param_size + self.fwd_reg_flow[region.r_id, region.r_id] = True + else: + self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1] + self.fwd_reg_flow[region.r_id, + self.reg_buffer_to_free] = False + self.reg_buffer_to_free.clear() + + # prefetch parameters of the next region + fwd_prefetch_region = region.fwd_prefetch_region + if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]): + self.runtime_mem += fwd_prefetch_region.param_size + self.fwd_reg_flow[region.r_id, + fwd_prefetch_region.r_id] = True + + for node in region.nodes: + self.runtime_mem += calculate_fwd_tmp(node) + \ + calculate_fwd_out(node) + self.peak_mem = max(self.runtime_mem, self.peak_mem) + + self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem + self.fwd_node_mem[node] = self.runtime_mem + + if region.need_offload: + self.runtime_mem -= region.param_size + + assert len( + self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}' + self.reg_buffer_to_free.append(region.r_id) + + def _eval_bwd_cost_per_region(self, region: Region): + """ + Evaluate computation and communication execution period of the region in backward pass. + """ + + # upload parameters of the current region + if region.is_syn: + assert region.need_offload + self._insert_h2d_exec(region, is_fwd=False) + + # prefetch parameters of the region choiced, which is parallel to computation + if region.bwd_prefetch_region is not None: + self._insert_h2d_exec(region.bwd_prefetch_region, is_fwd=False) + + # execute computation + self._insert_comp_exec(region, is_fwd=False) + + # offload gradient + if requires_offload_g_in_bwd(region): + self._insert_d2h_exec(region) + + assert len(self.reg_buffer_to_free) == 0 + for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items(): + if offl_exec.end_time >= self.last_comp.start_time: + break + self.reg_buffer_to_free.append(reg_id) + self.bwd_reg_to_offl_freed[reg_id] = offl_exec + + for reg_id in self.reg_buffer_to_free: + self.bwd_reg_to_offl_waiting.pop(reg_id) + + def _eval_bwd_mem_per_region(self, region: Region): + """ + Evaluate the runtime and peak memory when the backward execution reaches the current region. + """ + + if region.r_id + 1 < self.region_num: + self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1] + else: + self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1] + self.bwd_reg_flow[region.r_id, + self.reg_buffer_to_free] = False + + # free gradients in the buffer + while len(self.reg_buffer_to_free): + reg_id = self.reg_buffer_to_free.pop(0) + self.runtime_mem -= self.region_list[reg_id].param_size + + # upload parameters of the current region + if region.is_syn: + self.runtime_mem += region.param_size + self.bwd_reg_flow[region.r_id, region.r_id] = True + + # prefetch parameters of the region choiced + bwd_prefetch_region = region.bwd_prefetch_region + if bwd_prefetch_region: + self.runtime_mem += bwd_prefetch_region.param_size + self.bwd_reg_flow[region.r_id, + bwd_prefetch_region.r_id] = True + + # add the gradient of the parameter + if region.r_id < region.shared_rid: + # gradient accumulation is required for shared parameters + self.runtime_mem += 2.0 * region.param_size + else: + self.runtime_mem += region.param_size + + for node in region.nodes.__reversed__(): + + self.runtime_mem -= calculate_fwd_out(node) + self.runtime_mem += node.meta['bwd_mem_tmp'] + \ + node.meta['bwd_mem_out'] + self.peak_mem = max(self.runtime_mem, self.peak_mem) + + # The memory savings of a node may be negative due to parameter prefetch. + self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem + + self.bwd_node_mem[node] = self.runtime_mem + + self.runtime_mem -= (node.meta['bwd_mem_tmp'] + + calculate_fwd_tmp(node)) + + # free bwd_mem_out + self.bwd_node_deps[node] = len(node.all_input_nodes) + for user_node in node.users: + if user_node in self.bwd_node_deps: + self.bwd_node_deps[user_node] -= 1 + if self.bwd_node_deps[user_node] <= 0: + self.runtime_mem -= user_node.meta['bwd_mem_out'] + + if self.runtime_mem < 0: + raise ValueError(f"region id: {region.r_id}, node name: {node.name}, " + f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" + f"runtime memory computed less than 0, which is miscalculated!") + + # release parameters of the region + if requires_release_p_in_bwd(self.region_list[region.shared_rid]): + self.runtime_mem -= region.param_size diff --git a/colossalai/auto_parallel/offload/util.py b/colossalai/auto_parallel/offload/util.py new file mode 100644 index 0000000000000000000000000000000000000000..6b010512cc9c99b6dff4acf2e7bd97d12d8146c0 --- /dev/null +++ b/colossalai/auto_parallel/offload/util.py @@ -0,0 +1,99 @@ +from dataclasses import dataclass +from typing import List + +import torch + +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp + +from .region import Region + + +@dataclass +class NodeInfo: + node_id: int = 0 + runtime_fwd_mem: float = 0 + runtime_bwd_mem: float = 0 + + +class NvDevicePower: + """ + NVIDIA GPU computing performance (TFLOPs). + """ + + RTX3080_FP16 = 70 + RTX3080_FP32 = 34.1 + + RTX3090_FP16 = 71 + RTX3090_FP32 = 35.7 + + V100_FP16 = 31.4 + V100_FP32 = 15.7 + + A100_FP16 = 78 + A100_FP32 = 19.5 + + +class GlobalRuntimeInfo(metaclass=SingletonMeta): + + def __init__(self): + self.h2d_stream = torch.cuda.Stream() + self.d2h_stream = torch.cuda.Stream() + self.fwd_prefetch_event_map = {} + self.bwd_prefetch_event_map = {} + self.region_list = [] + + +def compute_act_peak_mem(region_list: List[Region]) -> float: + act_peak_mem = 0 + runtime_mem = 0 + # forward + for region in region_list: + for node in region.nodes: + runtime_mem = runtime_mem + \ + calculate_fwd_tmp(node) + calculate_fwd_out(node) + act_peak_mem = max(runtime_mem, act_peak_mem) + # backward + bwd_deps = {} + for region in region_list.__reversed__(): + for node in region.nodes.__reversed__(): + runtime_mem -= calculate_fwd_out(node) + runtime_mem = runtime_mem + \ + node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out'] + + act_peak_mem = max(runtime_mem, act_peak_mem) + + runtime_mem = runtime_mem - \ + node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node) + + # free bwd_mem_out + bwd_deps[node] = len(node.all_input_nodes) + for user_node in node.users: + if user_node in bwd_deps: + bwd_deps[user_node] -= 1 + if bwd_deps[user_node] <= 0: + runtime_mem -= user_node.meta['bwd_mem_out'] + + return act_peak_mem + + +def compute_max_param_mem(region_list: List[Region]) -> float: + return max(region.param_size for region in region_list) + + +def compute_total_param_mem(region_list: List[Region]) -> float: + return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid) + + +def requires_upload_p_in_fwd(shared_reg: Region): + return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid + and shared_reg.need_offload) + + +def requires_release_p_in_bwd(shared_reg: Region): + return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid + and shared_reg.need_offload) + + +def requires_offload_g_in_bwd(region: Region): + return region.param_size and (region.r_id <= region.shared_rid) diff --git a/colossalai/auto_parallel/passes/__init__.py b/colossalai/auto_parallel/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..ffda58e0689f1b0e7fa962dd328eee7453e559ec --- /dev/null +++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py @@ -0,0 +1,113 @@ +from typing import Dict + +import torch +from torch.fx import GraphModule +from torch.fx.node import Node + +from colossalai.auto_parallel.meta_profiler import ShardMetaInfo +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem +from colossalai.tensor.comm_spec import CommSpec +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +shape_consistency_manager = ShapeConsistencyManager() + + +def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, + target_sharding_spec: ShardingSpec) -> ShardMetaInfo: + # get comm_action_sequence and total_cost from shape_consistency_manager + _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( + origin_sharding_spec, target_sharding_spec) + + meta_info = ShardMetaInfo() + # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel + # get mem cost for ShardMetaInfo + mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) + # extract user that has _meta_data and extract element length + input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data')) + element_length = input_node._meta_data.element_size() + + mem_cost.fwd.activation *= element_length + mem_cost.fwd.temp *= element_length + mem_cost.bwd.activation *= element_length + mem_cost.bwd.temp *= element_length + mem_cost.total.activation *= element_length + + meta_info.memory_cost = mem_cost + + # get computation cost for ShardMetaInfo + meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, + total_cost['backward'] * element_length, + total_cost['total'] * element_length) + + # get tensor shape for ShardMetaInfo + origin_sharding_spec: ShardingSpec + target_sharding_spec: ShardingSpec + input_shape = origin_sharding_spec.get_sharded_shape_per_device() + output_shape = target_sharding_spec.get_sharded_shape_per_device() + + meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_buffer = [] + meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + + return meta_info + + +def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> ShardMetaInfo: + """ + This method is used to construct `MetaInto` for shape consistency node + """ + + # extract node index and user node index + args = node.args + node_index, user_node_index = args[3], args[4] + origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][ + user_node_index] + + return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec) + + +def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> ShardMetaInfo: + # extract node_index and op_data_name + node_index, op_data_name = node.args[2], node.args[3] + + comm_action = comm_actions_dict[node_index][op_data_name] + if isinstance(comm_action.comm_spec, CommSpec): + # this case is for all_reduce, there will be no memory cost + meta_info = ShardMetaInfo() + meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost) + output_node = next(n for n in node.users if hasattr(n, '_meta_data')) + element_length = output_node._meta_data.element_size() + + total_cost = comm_action.comm_spec.get_comm_cost() + meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, + total_cost['backward'] * element_length, + total_cost['total'] * element_length) + + input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device() + meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_buffer = [] + meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + else: + # this case will be handled by shape consistency manager + origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[ + 'tgt_spec'] + meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec) + + return meta_info + + +def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, + comm_actions_dict: Dict) -> GraphModule: + """ + The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph. + """ + for node in gm.graph.nodes: + if node.target == runtime_apply: + setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) + elif node.target == runtime_comm_spec_apply: + setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) + else: + pass + return gm diff --git a/colossalai/auto_parallel/passes/constants.py b/colossalai/auto_parallel/passes/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..485a87492f4ce5fca3f2401108e285449f931788 --- /dev/null +++ b/colossalai/auto_parallel/passes/constants.py @@ -0,0 +1,13 @@ +import torch + +OUTPUT_SAVED_OPS = [torch.nn.functional.relu, torch.nn.functional.softmax, torch.flatten] + +OUTPUT_SAVED_MOD = [ + torch.nn.ReLU, + torch.nn.Softmax, +] + +# SHAPE_ARGUMENT_OPS contains node with (input, *shape) style args. +# This list could be extended if any other method has the same +# argument style as view and reshape. +SHAPE_ARGUMENT_OPS = [torch.Tensor.view, torch.Tensor.reshape, torch.reshape] diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..bc0960483980b9242af531d4b309a9c74f735c16 --- /dev/null +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ -0,0 +1,165 @@ +import uuid +from dataclasses import asdict +from typing import List + +import torch +import torch.fx +from torch.fx import GraphModule +from torch.fx.node import Node + +from colossalai.auto_parallel.meta_profiler import ShardMetaInfo +from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS +from colossalai.fx._compatibility import compatibility +from colossalai.fx.profiler import GraphInfo + + +def _normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +@compatibility(is_backward_compatible=False) +class MetaInfoProp: + + def __init__(self, module: GraphModule) -> None: + self.module = module + self.func_dict = { + 'placeholder': self.placeholder_handler, + 'get_attr': self.get_attr_handler, + 'output': self.output_handler, + 'call_function': self.node_handler, + 'call_module': self.node_handler, + 'call_method': self.node_handler, + } + + def _set_data_ptr(self, x): + """ + Set uuid to tensor + """ + if isinstance(x, torch.Tensor): + if not x.data_ptr(): + data_ptr = uuid.uuid4() + x.data_ptr = lambda: data_ptr + + def _is_inplace(self, node: Node): + """ + Check if the node is inplace operation. + """ + if node.op == 'call_module': + return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD + elif node.op == "call_function": + return node.target in OUTPUT_SAVED_OPS + return False + + def run(self) -> GraphModule: + """ + Run the meta information propagation pass on the module. + """ + for node in self.module.graph.nodes: + node: Node + self.func_dict[node.op](node) + + @compatibility(is_backward_compatible=False) + def placeholder_handler(self, node: Node) -> None: + """ + Handle the placeholder node. + """ + graph_info = GraphInfo() + out = _normalize_tuple(getattr(node, '_meta_data', None)) + graph_info.fwd_out = list(out) if out[0] is not None else [] + node.meta = {**asdict(graph_info)} + + @compatibility(is_backward_compatible=False) + def get_attr_handler(self, node: Node) -> None: + """ + Handle the get_attr node. + """ + graph_info = GraphInfo() + node.meta = {**asdict(graph_info)} + + @compatibility(is_backward_compatible=False) + def output_handler(self, node: Node) -> None: + """ + Handle the output node. + """ + graph_info = GraphInfo() + output_tensors = [] + for par in node._input_nodes: + if par.meta: + output_tensors += par.meta["fwd_out"] + graph_info.fwd_in = output_tensors + node.meta = {**asdict(graph_info)} + + @compatibility(is_backward_compatible=False) + def node_handler(self, node: Node) -> None: + """ + Handle other kind of nodes + """ + assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}" + graph_info = GraphInfo() + meta_info = node.best_strategy_info + meta_info: ShardMetaInfo + + # set data_ptr for input_tensor in ShardMetaInfo class + input_tensors: List[torch.Tensor] = meta_info.fwd_in + buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer + output_tensors: List[torch.Tensor] = meta_info.fwd_out + + if self._is_inplace(node): + # inplace operation will not create new tensor, and it only has one parent node + # TODO: Verify this observation + # set data_ptr for input_tensor, buffer_tensor and output_tensor of current node + parent_node = list(node._input_nodes.keys())[0] + parent_tensor = parent_node.meta.get("fwd_out")[0] + parent_tensor: torch.Tensor + for tensor in input_tensors: + tensor.data_ptr = parent_tensor.data_ptr + for tensor in buffer_tensors: + tensor.data_ptr = parent_tensor.data_ptr + for tensor in output_tensors: + tensor.data_ptr = parent_tensor.data_ptr + + else: + for par in node._input_nodes: + # set data_ptr for the input_tensor of current node from the output_tensor of its parent node + for tensor in par.meta.get("fwd_out", []): + tensor: torch.Tensor + target_input_tensor = next( + (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None) + if target_input_tensor is not None: + target_input_tensor.data_ptr = tensor.data_ptr + + # set data_ptr for tensor in input_tensor that is not set + for tensor in input_tensors: + if not tensor.data_ptr(): + self._set_data_ptr(tensor) + + # set data_ptr for buffer_tensor + for tensor in buffer_tensors: + self._set_data_ptr(tensor) + + # set data_ptr for output_tensor + for tensor in output_tensors: + self._set_data_ptr(tensor) + + # attach them to graph_info + graph_info.fwd_in = input_tensors + graph_info.fwd_tmp = buffer_tensors + graph_info.fwd_out = output_tensors + + # fetch other memory informations + memory_cost = meta_info.memory_cost + graph_info.fwd_mem_tmp = memory_cost.fwd.temp + graph_info.fwd_mem_out = memory_cost.fwd.activation + graph_info.bwd_mem_tmp = memory_cost.bwd.temp + graph_info.bwd_mem_out = memory_cost.bwd.activation + + # fetch flop information + # here we use fwd_time and bwd_time to deal with the case that + # communication cost is a float + compute_cost = meta_info.compute_cost + graph_info.fwd_time = compute_cost.fwd + graph_info.bwd_time = compute_cost.bwd + + node.meta = {**asdict(graph_info)} diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..a473bb6e973de453f47c7bc16c2e83b8e7fe86df --- /dev/null +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -0,0 +1,256 @@ +from copy import deepcopy +from typing import Dict, List + +import torch +from torch.fx.node import Node + +from colossalai._analyzer.fx.node_util import MetaInfo +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + OperationData, + OperationDataType, + TrainCycleItem, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.comm_spec import CommSpec +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +shape_consistency_manager = ShapeConsistencyManager() + + +def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int): + """ + This method will be invoked during runtime to do the shape consistency, which make sure the activations is converted into + the user node expected form. + """ + origin_sharding_spec = origin_dict[node_index] + target_sharding_spec = input_dict[node_index][user_node_index] + return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec) + + +def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, + user_node_index: int): + """ + This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list + is converted into the user node expected form. + """ + rst = [] + for index, (origin_sharding_spec, + target_sharding_spec) in enumerate(zip(origin_dict[node_index], + input_dict[node_index][user_node_index])): + rst.append( + shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec, + target_sharding_spec)) + rst = type(node)(rst) + return rst + + +def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str): + """ + This method will be invoked during runtime to apply the comm action following the instruction of comm spec. + """ + comm_action = comm_actions_dict[node_index][op_data_name] + if isinstance(comm_action.comm_spec, CommSpec): + rst = comm_action.comm_spec.covert_spec_to_action(tensor) + else: + origin_sharding_spec = comm_action.comm_spec['src_spec'] + tgt_sharding_spec = comm_action.comm_spec['tgt_spec'] + rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec) + return rst + + +def _preprocess_graph(nodes: List[Node]): + """ + This method is used to extract all the placeholders with sharding information, + and mapping the nodes into the index of the origin graph. + """ + # mapping the node into the origin graph index + node_to_index_dict = {} + index = 0 + for node in nodes: + if node.target == 'sharding_spec_convert_dict': + input_dict_node = node + continue + if node.target == 'origin_node_sharding_spec_dict': + origin_dict_node = node + continue + if node.target == 'comm_actions_dict': + comm_actions_dict_node = node + continue + if not hasattr(node, 'best_strategy'): + continue + node_to_index_dict[node] = index + index += 1 + + return input_dict_node, origin_dict_node, comm_actions_dict_node, node_to_index_dict + + +def _shape_consistency_apply(gm: torch.fx.GraphModule): + """ + This pass is used to add the shape consistency node to the origin graph. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + + input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes) + + for node in nodes: + if not hasattr(node, 'best_strategy') or node.op == 'output': + continue + + for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes): + if isinstance(node.sharding_spec, (list, tuple)): + assert isinstance( + node.target_sharding_specs, + (list, + tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list' + total_difference = 0 + for sharding_spec, target_sharding_spec in zip(node.sharding_spec, + node.target_sharding_specs[user_node_index]): + total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec) + if total_difference == 0: + continue + with mod_graph.inserting_before(user_node): + shape_consistency_node = mod_graph.create_node('call_function', + runtime_apply_for_iterable_object, + args=(node, origin_dict_node, input_dict_node, + node_to_index_dict[node], user_node_index)) + + else: + assert isinstance(node.sharding_spec, + ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.' + if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0: + continue + with mod_graph.inserting_before(user_node): + shape_consistency_node = mod_graph.create_node('call_function', + runtime_apply, + args=(node, origin_dict_node, input_dict_node, + node_to_index_dict[node], user_node_index)) + if hasattr(user_node.meta['info'], 'activation_checkpoint'): + MetaInfo(shape_consistency_node, + mod_dir=user_node.meta['info'].mod_dir, + activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint)) + new_args = list(user_node.args) + new_kwargs = dict(user_node.kwargs) + # the origin node may be a positional argument or key word argument of user node + if node in new_args: + # substitute the origin node with shape_consistency_node + origin_index_args = new_args.index(node) + new_args[origin_index_args] = shape_consistency_node + user_node.args = tuple(new_args) + elif str(node) in new_kwargs: + # substitute the origin node with shape_consistency_node + new_kwargs[str(node)] = shape_consistency_node + user_node.kwargs = new_kwargs + + return gm + + +def _comm_spec_apply(gm: torch.fx.GraphModule): + """ + This pass is used to add the comm spec apply node to the origin graph. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + + _, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes) + + for node in nodes: + if not hasattr(node, 'best_strategy') or node.op == 'output': + continue + + comm_actions = node.best_strategy.communication_actions + for op_data, comm_action in comm_actions.items(): + + if comm_action.comm_type == CommType.HOOK: + continue + if comm_action.comm_type == CommType.BEFORE: + if op_data.type == OperationDataType.OUTPUT: + comm_object = node + elif comm_action.key_for_kwarg is not None: + comm_object = node.kwargs[comm_action.key_for_kwarg] + else: + comm_object = node.args[comm_action.arg_index] + with mod_graph.inserting_before(node): + comm_spec_apply_node = mod_graph.create_node('call_function', + runtime_comm_spec_apply, + args=(comm_object, comm_actions_dict_node, + node_to_index_dict[node], op_data.name)) + # the origin node may be a positional argument or key word argument of user node + if comm_action.key_for_kwarg is not None: + # substitute the origin node with comm_spec_apply_node + new_kwargs = dict(node.kwargs) + new_kwargs[comm_action.key_for_kwarg] = comm_spec_apply_node + node.kwargs = new_kwargs + else: + # substitute the origin node with comm_spec_apply_node + new_args = list(node.args) + new_args[comm_action.arg_index] = comm_spec_apply_node + node.args = tuple(new_args) + + elif comm_action.comm_type == CommType.AFTER: + with mod_graph.inserting_after(node): + comm_spec_apply_node = mod_graph.create_node('call_function', + runtime_comm_spec_apply, + args=(node, comm_actions_dict_node, + node_to_index_dict[node], op_data.name)) + user_list = list(node.users.keys()) + for user in user_list: + if user == comm_spec_apply_node: + continue + new_args = list(user.args) + new_kwargs = dict(user.kwargs) + # the origin node may be a positional argument or key word argument of user node + if node in new_args: + # substitute the origin node with comm_spec_apply_node + new_args[new_args.index(node)] = comm_spec_apply_node + user.args = tuple(new_args) + elif str(node) in new_kwargs: + # substitute the origin node with comm_spec_apply_node + new_kwargs[str(node)] = comm_spec_apply_node + user.kwargs = new_kwargs + if hasattr(node.meta['info'], 'activation_checkpoint'): + MetaInfo(comm_spec_apply_node, + mod_dir=node.meta['info'].mod_dir, + activation_checkpoint=tuple(node.meta['info'].activation_checkpoint)) + + return gm + + +def _act_annotataion_pass(gm: torch.fx.GraphModule): + """ + This pass is used to add the act annotation to the new inserted nodes. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + + for node in nodes: + if not hasattr(node.meta, 'activation_checkpoint'): + from .runtime_preparation_pass import size_processing + + user_act_annotation = -1 + input_act_annotation = -1 + for user_node in node.users.keys(): + if 'activation_checkpoint' in user_node.meta: + user_act_annotation = user_node.meta['activation_checkpoint'] + break + for input_node in node._input_nodes.keys(): + if 'activation_checkpoint' in input_node.meta: + input_act_annotation = input_node.meta['activation_checkpoint'] + break + if user_act_annotation == input_act_annotation and user_act_annotation != -1: + node.meta['activation_checkpoint'] = user_act_annotation + + return gm + + +def runtime_apply_pass(gm: torch.fx.GraphModule): + """ + The method manages all the passes acting on the distributed training runtime. + """ + gm = _shape_consistency_apply(gm) + gm = _comm_spec_apply(gm) + + return gm diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..08af846b221db60b3950a0cf285238f616b17711 --- /dev/null +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -0,0 +1,507 @@ +import operator +from copy import deepcopy +from typing import Dict, List, Union + +import torch +from torch.fx import symbolic_trace +from torch.fx.node import Node + +from colossalai._analyzer.fx.node_util import MetaInfo +from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + OperationDataType, + ShardingStrategy, +) +from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.comm_spec import _all_reduce +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +from .constants import SHAPE_ARGUMENT_OPS + +shape_consistency_manager = ShapeConsistencyManager() + + +def size_processing(size: Union[int, torch.Size], + dim_partition_dict: Dict[int, List[int]], + device_mesh_info: Dict[int, int], + target_dim: int = None, + node_name: str = None): + """ + This method will be invoked during runtime to convert size node value depending on distributed information. + """ + if target_dim is not None: + assert isinstance(size, int) + if target_dim in dim_partition_dict: + total_shard_size = 1 + for shard_dim in dim_partition_dict[target_dim]: + total_shard_size *= device_mesh_info[shard_dim] + size = size * total_shard_size + + else: + size = list(size) + for dim, dim_size in enumerate(size): + if dim in dim_partition_dict: + total_shard_size = 1 + for shard_dim in dim_partition_dict[dim]: + total_shard_size *= device_mesh_info[shard_dim] + size[dim] = dim_size * total_shard_size + size = torch.Size(size) + + return size + + +def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], + strategies_constructor: StrategiesConstructor): + """ + This method is used to stick the solution strategy to the nodes and add the information + required in runtime into graph as placeholder nodes. + """ + mod_graph = gm.graph + + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + no_strategy_nodes = strategies_constructor.no_strategy_nodes + + # the dict to get origin sharding spec of node + origin_node_sharding_spec_dict = {} + for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)): + strategies_vector = node.strategies_vector + # stick the solution strategy to the corresponding node + setattr(node, 'best_strategy', strategies_vector[strategy_index]) + setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))) + origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name( + str(node)) + + # attach the corresponding metainfo if node has the attribute `strategies_info` + if hasattr(node, 'strategies_info'): + setattr(node, 'best_strategy_info', node.strategies_info[strategy_index]) + + # the dict to get input sharding specs of user node + sharding_spec_convert_dict = {} + # the dict to record comm actions of nodes + comm_actions_dict = {} + for index, node in enumerate(nodes): + target_sharding_specs = [] + for user_node in node.strategies_vector.successor_nodes: + if user_node in no_strategy_nodes: + target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(str(node.name)) + else: + target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name)) + target_sharding_specs.append(target_sharding_spec) + sharding_spec_convert_dict[index] = target_sharding_specs + setattr(node, 'target_sharding_specs', target_sharding_specs) + + # the get_attr node strategy is kind of pending strategy, which means we will change it + # to the same strategy of the user node. + if node.op == 'get_attr': + assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.' + target_node = node.strategies_vector.successor_nodes[0] + node_name = str(node) + if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP: + node_name = str(target_node) + target_node = target_node.strategies_vector.successor_nodes[0] + user_strategy = target_node.best_strategy + op_data_in_user = user_strategy.get_op_data_by_name(node_name) + origin_pending_strategy = node.best_strategy + origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node)) + + new_communication_actions = {} + if op_data_in_user in user_strategy.communication_actions: + new_communication_action = user_strategy.communication_actions.pop(op_data_in_user) + new_communication_action.arg_index = 0 + new_communication_actions[origin_op_data] = new_communication_action + node.best_strategy.communication_actions = new_communication_actions + + comm_action_dict = {} + for op_data, comm_action in node.best_strategy.communication_actions.items(): + comm_action_dict[op_data.name] = comm_action + comm_actions_dict[index] = comm_action_dict + + # add above dicts into graph + for node in nodes: + if node.op != 'placeholder': + with mod_graph.inserting_before(node): + input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict') + origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict') + comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict') + break + return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict + + +def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): + """ + In the auto parallel system, tensors may get shard on different devices, so the size of tensors + need to be converted to the size of original tensor and managed by the users, such as torch.view, + torch.reshape, etc. These nodes have enough information like input sharding_spec and + output sharding_spec to decide how to convert the size value. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + node_pairs = {} + + # DeviceMesh information instructs the scaling of the size value + device_mesh_info = {} + for dim, dim_size in enumerate(device_mesh.mesh_shape): + device_mesh_info[dim] = dim_size + + def _extract_target_dim(node): + ''' + A helper function to extract the target dimension from size node. + There are two usages of torch.Tensor.size: + 1. tensor.size() + 2. tensor.size(dim) + + If a target_dim is assigned, then the output will be in type of int, instead of torch.Size. + Otherwise, the output will be in type of torch.Size and this function will return None. + ''' + target_dim = None + if len(node.args) > 1: + target_dim = node.args[1] + if target_dim < 0: + target_dim += node.args[0]._meta_data.dim() + return target_dim + + def _post_processing(node, size_processing_node): + ''' + This function is used to process the dependency between the size node and its users after + inserting the size_process_node. + ''' + # store original node and processing node pair in node_pairs dictioanry + # It will be used to replace the original node with processing node in slice object + node_pairs[node] = size_processing_node + size_processing_node._meta_data = node._meta_data + + if hasattr(node.meta['info'], 'activation_checkpoint'): + MetaInfo(size_processing_node, + mod_dir=node.meta['info'].mod_dir, + activation_checkpoint=tuple(node.meta['info'].activation_checkpoint)) + + user_list = list(node.users.keys()) + for user in user_list: + if user == size_processing_node: + continue + new_args = list(user.args) + new_kwargs = dict(user.kwargs) + # the origin node may be a positional argument or key word argument of user node + if node in new_args: + # substitute the origin node with size_processing_node + new_args[new_args.index(node)] = size_processing_node + user.args = tuple(new_args) + elif str(node) in new_kwargs: + # substitute the origin node with size_processing_node + new_kwargs[str(node)] = size_processing_node + user.kwargs = new_kwargs + + def _update_slice_object_args(slice_object): + ''' + This function is used to update the slice object argument list. + If the slice object contains the Node argument, then the size node will be replaced with + ''' + if isinstance(slice_object, slice): + start = slice_object.start + stop = slice_object.stop + step = slice_object.step + if start in node_pairs: + start = node_pairs[start] + if stop in node_pairs: + stop = node_pairs[stop] + if step in node_pairs: + step = node_pairs[step] + return slice(start, stop, step) + elif isinstance(slice_object, int): + if slice_object in node_pairs: + return node_pairs[slice_object] + else: + return slice_object + else: + raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}") + + for node in nodes: + + if node.op == 'call_method' and node.target == 'size': + # extract useful information from size node + # dim_partition_dict will instruct the size value on which + # dimension should be enlarged. + sharding_spec = node.args[0].sharding_spec + dim_partition_dict = sharding_spec.dim_partition_dict + + target_dim = _extract_target_dim(node) + + # insert size_processing node + with mod_graph.inserting_after(node): + size_processing_node = mod_graph.create_node('call_function', + size_processing, + args=(node, dim_partition_dict, device_mesh_info, + target_dim, node.name)) + _post_processing(node, size_processing_node) + + if node.op == 'call_function' and node.target == operator.getitem: + + getitem_index = node.args[1] + # slice object is quite special in torch.fx graph, + # On one side, we treat slice object same as type of int, + # so we do not create a node for slice object. On the other side, + # slice object could take fx.Node as its argument. And the user + # relationship cannot be tracked in fx graph. + # Therefore, I record the node_pairs in this pass, and use the it + # to replace the original node argument inside the slice object if + # it has been processed in above pass. + + # There are three main usages of operator.getitem: + # getitem(input, int) + # getitem(input, slice) + # getitem(input, Tuple[slice]) + # In this pass, we need process the last two cases because + # node arguments may potentially appear in these cases. + if isinstance(getitem_index, slice): + new_slice_item = _update_slice_object_args(getitem_index) + new_args = (node.args[0], new_slice_item) + node.args = new_args + + elif isinstance(getitem_index, (tuple, list)): + if not isinstance(getitem_index[0], slice): + continue + new_slice_items = [] + + for slice_item in getitem_index: + if slice_item is None: + new_slice_items.append(None) + continue + new_slice_item = _update_slice_object_args(slice_item) + new_slice_items.append(new_slice_item) + + new_args = (node.args[0], tuple(new_slice_items)) + node.args = new_args + + return gm + + +def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): + """ + This pass will process node args to adapt the distributed tensor layout. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + + def _extract_info_from_sharding_spec(sharding_spec): + ''' + This function is used to extract the dim_partition_dict and device_mesh from + sharding spec instance or a list of sharding spec. + ''' + if isinstance(sharding_spec, ShardingSpec): + dim_partition_dict = sharding_spec.dim_partition_dict + device_mesh = sharding_spec.device_mesh + return dim_partition_dict, device_mesh + if sharding_spec is None: + return None, None + assert isinstance(sharding_spec, + (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None' + + device_mesh = sharding_spec[0].device_mesh + dim_partition_dict = [] + for element in sharding_spec: + dim_partition_dict.append(_extract_info_from_sharding_spec(element)) + return dim_partition_dict, sharding_spec + + def _process_node_arguments(node): + new_args = [] + for arg in node.args: + # There are two args style: + # 1. (input, *shape) + # 2. (input, shape) + # We will extract the elements from shape and add them into the new_args + # Finally, the args style of new_args will be unified to (input, *shape) + if isinstance(arg, Node): + if isinstance(arg._meta_data, (tuple, list)): + new_args.extend(arg._meta_data) + elif isinstance(arg._meta_data, int): + new_args.append(arg._meta_data) + else: + new_args.append(arg) + else: + assert isinstance(arg, + (int, tuple, list)), 'The argument in view node should be either type of Node or int.' + if isinstance(arg, (tuple, list)): + new_args.extend(arg) + else: + new_args.append(arg) + return new_args + + def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node): + new_args = _process_node_arguments(node) + if node.op == 'call_method': + args_to_process = list(new_args[1:]) + else: + args_to_process = list(new_args) + for dim, shard_dims in dim_partition_dict.items(): + total_shard_size = 1 + for shard_dim in shard_dims: + total_shard_size *= device_mesh.shape[shard_dim] + + # we will skip the dim with -1 value + if args_to_process[dim] == -1: + continue + else: + # TODO: add assertion here to make sure the dim size is divisible by total_shard_size + args_to_process[dim] //= total_shard_size + + args_to_process = tuple(args_to_process) + + if node.op == 'call_method': + new_args = (new_args[0],) + args_to_process + else: + new_args = args_to_process + + node.args = new_args + + def _filter_node_with_shape_args(node): + if node.op == 'call_method': + target = getattr(node.args[0]._meta_data.__class__, node.target) + elif node.op == 'call_function': + target = node.target + else: + target = None + + if target in SHAPE_ARGUMENT_OPS: + return True + return False + + for node in nodes: + # skip the placeholder node added in _solution_annotation pass + if not hasattr(node, 'sharding_spec'): + continue + + output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec) + if _filter_node_with_shape_args(node): + _scale_args_adapt_sharding_spec(output_dim_partition_dict, device_mesh, node) + + return gm + + +def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False): + """ + Apply the sharding action to the module parameters and buffers following the + instructions of solver solution. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + # This stream is created for overlaping the communication and computation. + reduction_stream = torch.cuda.Stream() + + def _add_hook_for_grad_communication(node, param, name=None): + + comm_actions = node.best_strategy.communication_actions + + def _filter_param_to_hook(node, op_data, comm_action, name): + + if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK: + return True + if node.op == 'get_attr' and isinstance( + node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: + return True + return False + + for operation_data, comm_action in comm_actions.items(): + comm_spec_to_use = comm_action.comm_spec + # register hook to the parameters + if _filter_param_to_hook(node, operation_data, comm_action, name=name): + + def wrapper(param, comm_spec, stream, overlap): + + def hook_fn(grad): + if overlap: + with torch.cuda.stream(stream): + _all_reduce(grad, comm_spec, async_op=True) + else: + _all_reduce(grad, comm_spec, async_op=False) + + param.register_hook(hook_fn) + + wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap) + + def _shard_param(param, target_sharding_spec): + # apply the sharding spec of parameters + if target_sharding_spec.dim_partition_dict != {}: + origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {}) + setattr(param, 'sharding_spec', origin_sharding_spec) + # TODO: build a ColoParameter class to manager the distributed parameters + # we could use .data here, because all the operations just happen before the real training + # loop, so we don't need to track these operations in the autograd graph. + param = torch.nn.Parameter( + shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, + target_sharding_spec).detach().clone()) + return param + + for node in nodes: + if node.op == 'call_module': + target_module = node.graph.owning_module.get_submodule(node.target) + # TODO: we need to do more actions to take care of the shared parameters. + if hasattr(target_module, 'processed') and target_module.processed: + continue + setattr(target_module, 'processed', True) + for name, param in target_module.named_parameters(): + target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) + param = _shard_param(param, target_sharding_spec) + + setattr(target_module, name, param) + _add_hook_for_grad_communication(node, param, name) + + sharded_buffer_dict = {} + # apply the sharding spec of buffers + for name, buffer in target_module.named_buffers(): + origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {}) + setattr(buffer, 'sharding_spec', origin_sharding_spec) + target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) + buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec) + sharded_buffer_dict[name] = buffer_sharded + + for name, buffer_sharded in sharded_buffer_dict.items(): + setattr(target_module, name, buffer_sharded.detach().clone()) + + if node.op == 'get_attr': + root = node.graph.owning_module + atoms = node.target.split(".") + attr_len = len(atoms) + if attr_len == 1: + target_module = root + target = getattr(root, atoms[0]) + else: + target_module = root + for atom in atoms[:-1]: + target_module = getattr(target_module, atom) + target = getattr(target_module, atoms[-1]) + + target_sharding_spec = node.sharding_spec + target = _shard_param(target, target_sharding_spec) + + assert hasattr(target_module, atoms[-1]) + setattr(target_module, atoms[-1], target) + _add_hook_for_grad_communication(node, target) + + return gm + + +def implicit_comm_action_apply(gm: torch.fx.GraphModule): + """ + replace the origin kernel into kernel with implicit communication inside. + """ + pass + + +def runtime_preparation_pass(gm: torch.fx.GraphModule, + solution: List[int], + device_mesh: DeviceMesh, + strategies_constructor: StrategiesConstructor, + overlap=False): + gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass( + gm, solution, strategies_constructor) + gm = size_value_converting_pass(gm, device_mesh) + gm = node_args_converting_pass(gm, device_mesh) + # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed. + # gm = implicit_comm_action_apply(gm) + gm = module_params_sharding_pass(gm, device_mesh, overlap=overlap) + + return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict diff --git a/colossalai/auto_parallel/pipeline_shard/__init__.py b/colossalai/auto_parallel/pipeline_shard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colossalai/auto_parallel/tensor_shard/__init__.py b/colossalai/auto_parallel/tensor_shard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colossalai/auto_parallel/tensor_shard/constants.py b/colossalai/auto_parallel/tensor_shard/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..99c1249340602daee1a1314f102bc600eae6667d --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/constants.py @@ -0,0 +1,91 @@ +import operator + +import torch + +__all__ = [ + 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', + 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP', + 'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST' +] + +ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] +ELEMENTWISE_FUNC_OP = [ + torch.abs, + torch.cos, + torch.exp, + operator.neg, + torch.multiply, + torch.nn.functional.relu, + torch.nn.functional.dropout, + # softmax should not be here + torch.nn.functional.softmax +] +ELEMENTWISE_METHOD_OP = [ + torch.Tensor.to, + torch.Tensor.type, + # TODO: contiguous maybe need some extra processes. + torch.Tensor.contiguous +] +RESHAPE_FUNC_OP = [ + torch.flatten, + torch.reshape, + torch.transpose, + torch.split, + torch.permute, + operator.getitem, +] +RESHAPE_METHOD_OP = [ + torch.Tensor.view, + torch.Tensor.unsqueeze, + torch.Tensor.split, + torch.Tensor.permute, + torch.Tensor.transpose, +] +BCAST_FUNC_OP = [ + torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub, + operator.mul, operator.floordiv, operator.truediv, torch.matmul, operator.pow, torch.pow +] +CONV_MODULE_OP = [ + torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d +] +CONV_FUNC_OP = [ + torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d +] +EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding] +LINEAR_MODULE_OP = [torch.nn.Linear] +LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm] +BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm] +LAYERNORM_MODULE_OP = [torch.nn.LayerNorm] +POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d] +NON_PARAM_FUNC_OP = [ + torch.flatten, + torch.reshape, + torch.abs, + torch.cos, + torch.exp, + operator.neg, + torch.multiply, + torch.nn.functional.relu, + torch.nn.functional.dropout, + torch.flatten, + torch.where, + operator.pow, + torch.pow, + torch.tanh, + torch.add, + torch.sub, + torch.mul, + torch.div, + torch.floor_divide, + torch.true_divide, + operator.add, + operator.sub, + operator.mul, + operator.floordiv, + operator.truediv, + # softmax should not be here + torch.nn.functional.softmax +] + +INFINITY_COST = 1e13 diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..b406ca6fb7e0fd28a9a6d3e98365b093f73f7171 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/initialize.py @@ -0,0 +1,356 @@ +from typing import Dict, List, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.fx import GraphModule +from torch.fx.graph import Graph + +from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference +from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor +from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.sharding_spec import ShardingSpec + + +class ModuleWrapper(nn.Module): + ''' + This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict + into the forward function. + ''' + + def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]], + origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]): + ''' + Args: + module: the original module + sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required in user node. + origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor. + comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor. + ''' + super(ModuleWrapper, self).__init__() + self.module = module + self.sharding_spec_dict = sharding_spec_dict + self.origin_spec_dict = origin_spec_dict + self.comm_actions_dict = comm_actions_dict + + def forward(self, *args, **kwargs): + return self.module(*args, + sharding_spec_convert_dict=self.sharding_spec_dict, + origin_node_sharding_spec_dict=self.origin_spec_dict, + comm_actions_dict=self.comm_actions_dict, + **kwargs) + + +def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable): + ''' + This method is used to extract the meta_args from the dataloader under the instruction of the data_process_func. + ''' + # TODO: implement this function + pass + + +def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]): + ''' + This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape + from the alpha_beta_dict. These two values will be used to estimate the communication cost. + ''' + # TODO: implement this function + pass + + +def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, + shard_option: str): + ''' + This method is used to build the strategy_constructor for the given graph. + After this method, each node in the graph will have a strategies_vector which + is constructed by the related node handler. + ''' + if solver_preference == 'standard': + solver_preference = SolverPerference.STANDARD + elif solver_preference == 'tp': + solver_preference = SolverPerference.TP + elif solver_preference == 'dp': + solver_preference = SolverPerference.DP + else: + raise ValueError(f'Invalid solver_preference: {solver_preference}') + + if dataloader_option == 'replicated': + dataloader_option = DataloaderOption.REPLICATED + elif dataloader_option == 'distributed': + dataloader_option = DataloaderOption.DISTRIBUTED + else: + raise ValueError(f'Invalid dataloader_option: {dataloader_option}') + + if shard_option == 'standard': + shard_option = ShardOption.STANDARD + elif shard_option == 'shard': + shard_option = ShardOption.SHARD + elif shard_option == 'shard_last_axis': + shard_option = ShardOption.SHARD_LAST_AXIS + elif shard_option == 'full_shard': + shard_option = ShardOption.FULL_SHARD + else: + raise ValueError(f'Invalid shard_option: {shard_option}') + + solver_options = SolverOptions(solver_perference=solver_preference, + dataloader_option=dataloader_option, + shard_option=shard_option) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + return strategies_constructor + + +def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0): + ''' + This method is used to solve the best solution for the given graph. + The solution is a list of integers, each integer represents the best strategy index of the corresponding node. + ''' + # temporarily we use all nodes as liveness list, we count the backward memory cost together with + # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase. + # graph_analyser = GraphAnalyser(gm) + # liveness_list = graph_analyser.liveness_analysis() + cost_graph = CostGraph(strategy_constructor.leaf_strategies) + cost_graph.simplify_graph() + solver = Solver(gm.graph, strategy_constructor, cost_graph, memory_budget=memory_budget) + ret = solver.call_solver_serialized_args() + solution = list(ret[0]) + + return solution + + +def transform_to_sharded_model(gm: ColoGraphModule, + meta_args: Dict, + solution: List[int], + device_mesh: DeviceMesh, + strategies_constructor: StrategiesConstructor, + overlap: bool = False): + ''' + This method is used to transform the original graph to the sharded graph. + The model parameters will be sharded according to the solution and the grad hooks + will be added to the sharded graph using the runtime_preparation_pass. + The communication node will be added into the graph using the runtime_apply_pass. + ''' + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, + solution, + device_mesh, + strategies_constructor, + overlap=overlap) + gm = runtime_apply_pass(gm) + shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict) + gm.recompile() + sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict) + + return gm, sharding_spec_dicts + + +def initialize_device_mesh(world_size: int = -1, + physical_devices: List[int] = None, + alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, + logical_mesh_shape: Tuple[int] = None, + logical_mesh_id: torch.Tensor = None): + ''' + This method is used to initialize the device mesh. + + Args: + world_size: the size of device mesh. If the world_size is -1, + the world size will be set to the number of GPUs in the current machine. + physical_devices: the physical devices used to initialize the device mesh. + alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values + for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be + generated by profile_alpha_beta function. + logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical + mesh shape. + logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id. + ''' + # if world_size is not set, use the world size from torch.distributed + if world_size == -1: + world_size = dist.get_world_size() + + if physical_devices is None: + physical_devices = [i for i in range(world_size)] + physical_mesh = torch.tensor(physical_devices) + + if alpha_beta_dict is None: + # if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device + ab_profiler = AlphaBetaProfiler(physical_devices) + alpha_beta_dict = ab_profiler.alpha_beta_dict + else: + ab_profiler = AlphaBetaProfiler(physical_devices, alpha_beta_dict=alpha_beta_dict) + + if logical_mesh_shape is None and logical_mesh_id is None: + # search for the best logical mesh shape + logical_mesh_id = ab_profiler.search_best_logical_mesh() + logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int) + logical_mesh_shape = logical_mesh_id.shape + + # extract alpha and beta values for the chosen logical mesh shape + mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh() + + elif logical_mesh_shape is not None and logical_mesh_id is None: + logical_mesh_id = physical_mesh.reshape(logical_mesh_shape) + + # extract alpha and beta values for the chosen logical mesh shape + mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id) + + device_mesh = DeviceMesh(physical_mesh_id=physical_mesh, + logical_mesh_id=logical_mesh_id, + mesh_alpha=mesh_alpha, + mesh_beta=mesh_beta, + init_process_group=True) + return device_mesh + + +def initialize_model(model: nn.Module, + meta_args: Dict[str, torch.Tensor], + device_mesh: DeviceMesh, + memory_budget: float = -1.0, + overlap: bool = False, + solver_preference: str = 'standard', + dataloader_option: str = 'replicated', + shard_option: str = 'standard', + save_solver_solution: bool = False, + load_solver_solution: bool = False, + solution_path: str = None, + return_solution: bool = False): + ''' + This method is used to initialize the sharded model which could be used as normal pytorch model. + + Args: + model: the model to be sharded. + meta_args: the meta_args is used to specify the input shapes of the model. + device_mesh: the device mesh to execute the model. + memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0, + the memory budget will be infinity. + overlap(optional): the overlap is used to specify whether to overlap gradient communication and + backward computing. + solver_preference(optional): the solver_preference is used to specify which parallelism algorithm + has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'. + dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will + be used. The valid dataloader_option could be 'replicated' or 'distributed'. + shard_option(optional): the shard_option is used to specify how many axes will be used to shard the + model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'. + save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved + to the solution_path. + load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded + from the solution_path. + solution_path(optional): the path to save or load the solution. + return_solution(optional): if the return_solution is True, the solution will be returned. The returned + solution will be used to debug or help to analyze the sharding result. Therefore, we will not just + return a series of integers, but return the best strategies. + ''' + tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True) + + graph = tracer.trace(root=model, meta_args=meta_args) + graph.set_codegen(ActivationCheckpointCodeGen()) + gm = ColoGraphModule(model, graph, model.__class__.__name__) + + shape_prop_pass(gm, *meta_args.values()) + gm.recompile() + + strategies_constructor = build_strategy_constructor(graph, + device_mesh, + solver_preference=solver_preference, + dataloader_option=dataloader_option, + shard_option=shard_option) + if load_solver_solution: + solution = torch.load(solution_path) + else: + solution = solve_solution(gm, strategies_constructor, memory_budget) + if save_solver_solution: + torch.save(solution, solution_path) + + gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor, + overlap) + + model_to_return = ModuleWrapper(gm, *sharding_spec_dicts) + + if return_solution: + solution_to_return = [] + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + for index, node in enumerate(nodes): + solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}') + return model_to_return, solution_to_return + else: + return model_to_return + + +def autoparallelize(model: nn.Module, + meta_args: Dict[str, torch.Tensor] = None, + data_loader: torch.utils.data.DataLoader = None, + data_process_func: callable = None, + alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, + logical_mesh_shape: Tuple[int] = None, + logical_mesh_id: torch.Tensor = None, + solver_preference: str = 'standard', + dataloader_option: str = 'replicated', + shard_option: str = 'standard', + save_solver_solution: bool = False, + load_solver_solution: bool = False, + solver_solution_path: str = None, + return_solution: bool = False, + memory_budget: float = -1.0): + ''' + This method is used to initialize the device mesh, extract the meta_args, and + use them to create a sharded model. + + Args: + model: the model to be sharded. + meta_args(optional): the meta_args is used to specify the input shapes of the model. + If the meta_args is None, the meta_args will be extracted from the data_loader. + data_loader(optional): the data_loader to be used in normal training loop. + data_process_func(optional): the data_process_func is used to process the data from the data_loader. + alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values + for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be + generated by profile_alpha_beta function. + logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical + mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be + generated by search_best_logical_mesh_shape function. + logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id. + solver_preference(optional): the solver_preference is used to specify which parallelism algorithm + has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'. + dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will + be used. The valid dataloader_option could be 'replicated' or 'distributed'. + shard_option(optional): the shard_option is used to specify how many axes will be used to shard the + model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'. + save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved + to the solution_path. + load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded + from the solution_path. + solver_solution_path(optional): the path to save or load the solution. + return_solution(optional): if the return_solution is True, the solution will be returned. + memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0, + the memory budget will be infinity. + ''' + device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, + logical_mesh_shape=logical_mesh_shape, + logical_mesh_id=logical_mesh_id) + if meta_args is None: + meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func) + + rst_to_unpack = initialize_model(model, + meta_args, + device_mesh, + solver_preference=solver_preference, + dataloader_option=dataloader_option, + shard_option=shard_option, + save_solver_solution=save_solver_solution, + load_solver_solution=load_solver_solution, + solution_path=solver_solution_path, + return_solution=return_solution, + memory_budget=memory_budget) + + if return_solution: + model, solution = rst_to_unpack + return model, solution + else: + model = rst_to_unpack + return model diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9903ca54e52cb70559cce2c68169c84ca08bef9c --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py @@ -0,0 +1,35 @@ +from .addmm_handler import ADDMMFunctionHandler +from .batch_norm_handler import BatchNormModuleHandler +from .binary_elementwise_handler import BinaryElementwiseHandler +from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler +from .conv_handler import ConvFunctionHandler, ConvModuleHandler +from .default_reshape_handler import DefaultReshapeHandler +from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler +from .getattr_handler import GetattrHandler +from .getitem_handler import GetItemHandler +from .layer_norm_handler import LayerNormModuleHandler +from .linear_handler import LinearFunctionHandler, LinearModuleHandler +from .matmul_handler import MatMulHandler +from .normal_pooling_handler import NormPoolingHandler +from .output_handler import OutputHandler +from .permute_handler import PermuteHandler +from .placeholder_handler import PlaceholderHandler +from .registry import operator_registry +from .softmax_handler import SoftmaxHandler +from .split_handler import SplitHandler +from .sum_handler import SumHandler +from .tensor_constructor_handler import TensorConstructorHandler +from .transpose_handler import TransposeHandler +from .unary_elementwise_handler import UnaryElementwiseHandler +from .view_handler import ViewHandler +from .where_handler import WhereHandler + +__all__ = [ + 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler', + 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', + 'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler', + 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', + 'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler', + 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler', + 'SplitHandler' +] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..da0d199c5e05b37340cb4ddcfee0b52a9102fadf --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py @@ -0,0 +1,91 @@ +from typing import Dict, List, Union + +import torch + +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager + +from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy +from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator + +__all__ = ['ADDMMFunctionHandler'] + + +@operator_registry.register(torch.addmm) +@operator_registry.register(torch.Tensor.addmm) +class ADDMMFunctionHandler(NodeHandler): + """ + This is a NodeHandler class which deals with the batched matrix multiplication operation in PyTorch. + Such operations including `torch.bmm` and `torch.Tensor.bmm` require the tensor to be 3D, thus, there is + no logical-physical shape conversion in this handler. + """ + + def _infer_op_data_type(self, tensor: torch.Tensor) -> OperationDataType: + if isinstance(tensor, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + return data_type + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + + # input operand + input_data = self.node.args[1]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[1]), + type=self._infer_op_data_type(input_data), + data=input_data) + + # other operand + other_data = self.node.args[2]._meta_data + physical_other_operand = OperationData(name=str(self.node.args[2]), + type=self._infer_op_data_type(other_data), + data=other_data) + # bias physical shape + bias_logical_shape = self.node._meta_data.shape + bias_data = self.node.args[0]._meta_data + physical_bias_operand = OperationData(name=str(self.node.args[0]), + type=self._infer_op_data_type(bias_data), + data=bias_data, + logical_shape=bias_logical_shape) + + # output + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = { + "input": physical_input_operand, + "other": physical_other_operand, + "output": physical_output, + 'bias': physical_bias_operand + } + + return mapping + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append( + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm')) + return generators + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + # convert bias from its logical sharding spec to its physical sharding spec + op_data_mapping = self.get_operation_data_mapping() + + bias_op_data = op_data_mapping['bias'] + bias_physical_shape = bias_op_data.data.shape + bias_logical_shape = bias_op_data.logical_shape + bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) + bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + bias_sharding_spec, bias_logical_shape, bias_physical_shape) + strategy.sharding_specs[bias_op_data] = bias_sharding_spec + + if len(removed_dims) > 0: + comm_action = comm_actions_for_oprands(node=self.node, + removed_dims=removed_dims, + op_data=bias_op_data, + sharding_spec=bias_sharding_spec) + strategy.communication_actions[bias_op_data] = comm_action + + return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..cb1bb36b78796db3d5656213518376f8f365dce0 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py @@ -0,0 +1,69 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector +from .node_handler import MetaInfoModuleHandler, ModuleHandler +from .registry import operator_registry +from .strategy import BatchNormStrategyGenerator, StrategyGenerator + +__all__ = ['BatchNormModuleHandler'] + + +@operator_registry.register(torch.nn.BatchNorm1d) +@operator_registry.register(torch.nn.BatchNorm2d) +@operator_registry.register(torch.nn.BatchNorm3d) +class BatchNormModuleHandler(MetaInfoModuleHandler): + """ + A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(BatchNormStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_other_operand = OperationData(name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters['weight'], + logical_shape=self.named_parameters['weight'].shape) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + physical_running_mean_operand = OperationData(name="running_mean", + type=OperationDataType.BUFFER, + data=self.named_buffers['running_mean'], + logical_shape=self.named_buffers['running_mean'].shape) + + physical_running_var_operand = OperationData(name="running_var", + type=OperationDataType.BUFFER, + data=self.named_buffers['running_var'], + logical_shape=self.named_buffers['running_var'].shape) + + physical_num_batches_tracked_operand = OperationData( + name="num_batches_tracked", + type=OperationDataType.BUFFER, + data=self.named_buffers['num_batches_tracked'], + logical_shape=self.named_buffers['num_batches_tracked'].shape) + + mapping = { + "input": physical_input_operand, + "other": physical_other_operand, + "output": physical_output, + "running_mean": physical_running_mean_operand, + "running_var": physical_running_var_operand, + "num_batches_tracked": physical_num_batches_tracked_operand + } + + if self.named_parameters['bias'] is not None: + physical_bias_operand = OperationData(name="bias", + type=OperationDataType.PARAM, + data=self.named_parameters['bias']) + mapping['bias'] = physical_bias_operand + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..db8f0b54ddeeb1c5250951f0c9e8bfef364eb16d --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py @@ -0,0 +1,113 @@ +from typing import Dict, List, Union + +import torch +from torch.fx.node import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager + +from ..constants import BCAST_FUNC_OP +from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape +from .node_handler import MetaInfoNodeHandler, NodeHandler +from .registry import operator_registry +from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator + +__all__ = ['BinaryElementwiseHandler'] + + +@operator_registry.register(BCAST_FUNC_OP) +class BinaryElementwiseHandler(MetaInfoNodeHandler): + """ + An BinaryBcastOpHandler is a node handler which deals with operations which have two + operands and broadcasting occurs such as torch.add. + """ + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + bcast_shape = self.node._meta_data.shape + + def _get_op_data_type(tensor): + if isinstance(tensor, torch.nn.parameter.Parameter): + return OperationDataType.PARAM + else: + return OperationDataType.ARG + + def _get_arg_value(idx): + non_tensor = False + if isinstance(self.node.args[idx], Node): + meta_data = self.node.args[idx]._meta_data + # The meta_data of node type argument could also possibly be a non-tensor object. + if not isinstance(meta_data, torch.Tensor): + assert isinstance(meta_data, (int, float)) + meta_data = torch.Tensor([meta_data]).to('meta') + non_tensor = True + + else: + # this is in fact a real data like int 1 + # but we can deem it as meta data + # as it won't affect the strategy generation + assert isinstance(self.node.args[idx], (int, float)) + meta_data = torch.Tensor([self.node.args[idx]]).to('meta') + non_tensor = True + + return meta_data, non_tensor + + input_meta_data, non_tensor_input = _get_arg_value(0) + other_meta_data, non_tensor_other = _get_arg_value(1) + output_meta_data = self.node._meta_data + # we need record op_data with non-tensor data in this list, + # and filter the non-tensor op_data in post_process. + self.non_tensor_list = [] + # assert False + input_op_data = OperationData(name=str(self.node.args[0]), + type=_get_op_data_type(input_meta_data), + data=input_meta_data, + logical_shape=bcast_shape) + other_op_data = OperationData(name=str(self.node.args[1]), + type=_get_op_data_type(other_meta_data), + data=other_meta_data, + logical_shape=bcast_shape) + output_op_data = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=bcast_shape) + if non_tensor_input: + self.non_tensor_list.append(input_op_data) + if non_tensor_other: + self.non_tensor_list.append(other_op_data) + + mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} + return mapping + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(BinaryElementwiseStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + # convert bias from its logical sharding spec to its physical sharding spec + op_data_mapping = self.get_operation_data_mapping() + + for op_name, op_data in op_data_mapping.items(): + if op_data in self.non_tensor_list: + # remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2) + strategy.sharding_specs.pop(op_data) + + else: + # convert the logical sharding spec to physical sharding spec if broadcast + # e.g. torch.rand(4, 4) + torch.rand(4) + physical_shape = op_data.data.shape + logical_shape = op_data.logical_shape + sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) + sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + sharding_spec, logical_shape, physical_shape) + + strategy.sharding_specs[op_data] = sharding_spec + if len(removed_dims) > 0: + comm_action = comm_actions_for_oprands(node=self.node, + removed_dims=removed_dims, + op_data=op_data, + sharding_spec=sharding_spec) + strategy.communication_actions[op_data] = comm_action + + return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..da2b733c9f7afda075c10dc3dd17a0d4f42fbc01 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py @@ -0,0 +1,110 @@ +from typing import Dict, List, Union + +import torch + +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager + +from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy +from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator + +__all__ = ['BMMFunctionHandler', 'AddBMMFunctionHandler'] + + +def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): + """ + This function is a helper function which extracts the common logic for both `bmm` and `addbmm` + node handler to reduce code redundancy. + """ + # input operand + physical_input_operand = OperationData(name=str(node.args[input_idx]), + type=OperationDataType.ARG, + data=node.args[input_idx]._meta_data) + + # other operand + physical_other_operand = OperationData(name=str(node.args[other_idx]), + type=OperationDataType.ARG, + data=node.args[other_idx]._meta_data) + + # output + physical_output = OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data) + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if bias_idx is not None: + # bias physical shape + bias_logical_shape = node._meta_data.shape + physical_bias_operand = OperationData(name=str(node.args[bias_idx]), + type=OperationDataType.ARG, + data=node.args[bias_idx]._meta_data, + logical_shape=bias_logical_shape) + mapping['bias'] = physical_bias_operand + return mapping + + +@operator_registry.register(torch.bmm) +@operator_registry.register(torch.Tensor.bmm) +class BMMFunctionHandler(NodeHandler): + """ + This is a NodeHandler class which deals with the batched matrix multiplication operation in PyTorch. + Such operations including `torch.bmm` and `torch.Tensor.bmm` require the tensor to be 3D, thus, there is + no logical-physical shape conversion in this handler. + """ + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + mapping = _get_data_mapping_for_bmm_op(node=self.node, input_idx=0, other_idx=1) + return mapping + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + +@operator_registry.register(torch.addbmm) +@operator_registry.register(torch.Tensor.addbmm) +class AddBMMFunctionHandler(NodeHandler): + """ + This is a NodeHandler class which deals with the addition + batched matrix multiplication operation in PyTorch. + Such operations including `torch.addbmm` and `torch.Tensor.addbmm` require the two matmul tensor to be 3D. However, due to the + addition, logical-physical shape conversion is required for the bias term. + + As the addbmm operation will reduce the batch dimension, the bias is maximum 2D. + """ + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + mapping = _get_data_mapping_for_bmm_op(node=self.node, input_idx=1, other_idx=2, bias_idx=0) + return mapping + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generator = BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh) + # addbmm will shrink the first batch dim + generator.squeeze_batch_dim = True + generators.append(generator) + return generators + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + # convert bias from its logical sharding spec to its physical sharding spec + op_data_mapping = self.get_operation_data_mapping() + + if 'bias' in op_data_mapping: + bias_op_data = op_data_mapping['bias'] + bias_physical_shape = bias_op_data.data.shape + bias_logical_shape = bias_op_data.logical_shape + bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) + bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + bias_sharding_spec, bias_logical_shape, bias_physical_shape) + strategy.sharding_specs[bias_op_data] = bias_sharding_spec + + if len(removed_dims) > 0: + comm_action = comm_actions_for_oprands(node=self.node, + removed_dims=removed_dims, + op_data=bias_op_data, + sharding_spec=bias_sharding_spec) + strategy.communication_actions[bias_op_data] = comm_action + + return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..272b1c85630a8ab15145e701740a44e20d5103b8 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py @@ -0,0 +1,120 @@ +from typing import Dict, List + +import torch +import torch.nn.functional as F + +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector +from ..utils import transpose_partition_dim +from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler +from .registry import operator_registry +from .strategy import ConvStrategyGenerator, StrategyGenerator + +__all__ = ['ConvModuleHandler', 'ConvFunctionHandler'] + + +@operator_registry.register(torch.nn.Conv1d) +@operator_registry.register(torch.nn.Conv2d) +@operator_registry.register(torch.nn.Conv3d) +class ConvModuleHandler(MetaInfoModuleHandler): + """ + A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + logical_shape_for_weight = list(self.named_parameters["weight"].shape) + logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[ + 1], logical_shape_for_weight[0] + physical_other_operand = OperationData(name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters['weight'], + logical_shape=torch.Size(logical_shape_for_weight)) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if "bias" in self.named_parameters: + physical_bias_operand = OperationData(name="bias", + type=OperationDataType.PARAM, + data=self.named_parameters['bias']) + mapping['bias'] = physical_bias_operand + return mapping + + def post_process(self, strategy: ShardingStrategy): + """ + Convert the sharding spec of the weight parameter back to its original shape. + """ + for op_data, sharding_spec in strategy.input_sharding_specs.items(): + if op_data.name == "weight": + transpose_partition_dim(sharding_spec, 0, 1) + return strategy + + +@operator_registry.register(F.conv1d) +@operator_registry.register(F.conv2d) +@operator_registry.register(F.conv3d) +class ConvFunctionHandler(MetaInfoNodeHandler): + """ + A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + + # check if the other operand is a parameter + if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + logical_shape_for_weight = list(self.node.args[1]._meta_data.shape) + logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[ + 1], logical_shape_for_weight[0] + physical_other_operand = OperationData(name=str(self.node.args[1]), + type=data_type, + data=self.node.args[1]._meta_data, + logical_shape=torch.Size(logical_shape_for_weight)) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None: + # check if the other operand is a parameter + if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), + type=data_type, + data=self.node.kwargs["bias"]._meta_data) + mapping['bias'] = physical_bias_operand + return mapping + + def post_process(self, strategy: ShardingStrategy): + """ + Convert the sharding spec of the weight parameter back to its original shape. + """ + for op_data, sharding_spec in strategy.input_sharding_specs.items(): + if op_data.name == str(self.node.args[1]): + transpose_partition_dim(sharding_spec, 0, 1) + return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5b9f39e1fba75b44308d57569c3b8c0b5087c0 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py @@ -0,0 +1,71 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import MetaInfoNodeHandler, NodeHandler +from .registry import operator_registry +from .strategy import DefaultReshapeGenerator, StrategyGenerator + +__all__ = ['DefaultReshapeHandler'] + + +@operator_registry.register(torch.flatten) +@operator_registry.register(torch.Tensor.unsqueeze) +@operator_registry.register(torch.nn.AdaptiveAvgPool2d) +class DefaultReshapeHandler(MetaInfoNodeHandler): + """ + A DefaultReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(DefaultReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def infer_logical_shape(self, data): + """ + This function is used to infer logical shape for operands. + + Notes: This function is only used for the operands whose data are not only in type of tensor, + such as tuple of tensor. + """ + if isinstance(data, torch.Tensor): + return data.shape + else: + assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor." + logical_shape = [] + for tensor in data: + assert isinstance(tensor, torch.Tensor), "input_data should be a tuple of tensor or a tensor." + logical_shape.append(tensor.shape) + logical_shape = tuple(logical_shape) + return logical_shape + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + input_logical_shape = self.infer_logical_shape(input_data) + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=data_type, + data=input_data, + logical_shape=input_logical_shape) + + output_data = self.node._meta_data + output_logical_shape = self.infer_logical_shape(output_data) + physical_output = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_data, + logical_shape=output_logical_shape) + + mapping = {"input": physical_input_operand, "output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..e154105b672de30b5675ca56147fb6b68205a469 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py @@ -0,0 +1,230 @@ +from typing import Dict, List, Union + +import torch +import torch.nn.functional as F + +from colossalai.auto_parallel.tensor_shard.utils import update_partition_dim +from colossalai.logging import get_dist_logger +from colossalai.tensor.sharding_spec import ShardingNotDivisibleError + +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from .node_handler import ModuleHandler, NodeHandler +from .registry import operator_registry +from .strategy import EmbeddingStrategyGenerator, StrategyGenerator + +__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler'] + + +def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str, + output_name: str) -> List[ShardingStrategy]: + """ + This function converts the logical sharding spec to the physical sharding spec for both the input and output + of the embedding operation. + + Args: + strategy (ShardingStrategy): the logical strategy generated by the strategy generator. + input_name (str): the name of the OperationData object for the input. + output_name (str): the name of the OperationData object for the output. + """ + # the result will be a list of strategies + sharding_strategies = [] + + # get operation data + input_op_data = strategy.get_op_data_by_name(input_name) + output_op_data = strategy.get_op_data_by_name(output_name) + input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name) + + # recover the last logical dimension to physical dimension + last_logical_output_dims = len(output_op_data.logical_shape) - 1 + last_physical_output_dims = output_op_data.data.dim() - 1 + + # get logger for debug message + logger = get_dist_logger() + + # For the input of the embedding operation, it can be multi-dimensional. The sharding spec is only generated for + # logical 1D non-matrix dimension, the logical non-matrix dimension can belong to the 0th to Nth dimension of the + # physical input shape. Thus, we enumerate to get all possible cases. + if input_sharding_spec.dim_partition_dict: + # if bool(input_sharding_spec.dim_partition_dict), it means that the + # the generated sharding strategy does shard the non-matrix dimension, + # in this case, we need to do enumeration + num_input_dims = input_op_data.data.dim() + for i in range(num_input_dims): + strategy_copy = strategy.clone() + input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) + try: + # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding + update_partition_dim(sharding_spec=input_sharding_spec, + dim_mapping={0: i}, + physical_shape=input_op_data.data.shape, + inplace=True) + + if last_logical_output_dims in output_sharding_spec.dim_partition_dict: + dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims} + else: + dim_mapping = {0: i} + + update_partition_dim(sharding_spec=output_sharding_spec, + dim_mapping=dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True) + + strategy_copy.name = f'{strategy.name}_{i}' + sharding_strategies.append(strategy_copy) + + except ShardingNotDivisibleError as e: + logger.debug( + f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}' + ) + else: + # the generated sharding strategy does not shard the non-matrix dimension, + # in this case, we don't need to do enumeration + # but instead, we still need to convert the logical shape to physical shape + strategy_copy = strategy.clone() + input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) + + # after updating, the logical shape will be replaced by the physical shape + update_partition_dim(sharding_spec=input_sharding_spec, + dim_mapping={}, + physical_shape=input_op_data.data.shape, + inplace=True) + + if last_logical_output_dims in output_sharding_spec.dim_partition_dict: + dim_mapping = {last_logical_output_dims: last_physical_output_dims} + else: + dim_mapping = {} + + update_partition_dim(sharding_spec=output_sharding_spec, + dim_mapping=dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True) + sharding_strategies.append(strategy_copy) + + return sharding_strategies + + +@operator_registry.register(torch.nn.Embedding) +class EmbeddingModuleHandler(ModuleHandler): + """ + A EmbeddingModuleHandler which deals with the sharding strategies for nn.Embedding module. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # In nn.Embedding operation, all the dimensions of input will be treated as the batch dimension, + # and then the sharding spec will be generated based on the logical 1D tensor. + # After that, the logical sharding info will be enumerated among all the physical dimensions. + # Finally, the input will be transformed back to its original shape in self.post_process + input_meta_data = self.node.args[0]._meta_data + input_logical_shape = input_meta_data.view(-1).shape + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=input_meta_data, + logical_shape=input_logical_shape) + + physical_other_operand = OperationData(name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters['weight']) + + # Same as input, in nn.Embedding operation, all the dimensions of output will be treated as + # (batch dimension, embedding dimension), and then the sharding spec will be generated based + # on the logical 2D tensor. + # After that, the logical sharding info of batch dimension will be enumerated among all the physical dimensions. + # Finally, the output will be transformed back to its original shape in self.post_process + output_meta_data = self.node._meta_data + output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape + physical_output = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=output_logical_shape) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + return mapping + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + """ + Convert the sharding spec from the logical shape to the physical shape. + """ + # create multiple sharding strategies for the inputs + # as input can be multi-dimensinal and the partition dim is only 2D, + # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy, + input_name=str( + self.node.args[0]), + output_name=str(self.node)) + return strategies + + +@operator_registry.register(F.embedding) +class EmbeddingFunctionHandler(NodeHandler): + """ + A EmbeddingFunctionHandler which deals with the sharding strategies for F.embedding. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # In F.embedding operation, all the dimensions of input will be treated as the batch dimension, + # and then the sharding spec will be generated based on the logical 1D tensor. + # After that, the logical sharding info will be enumerated among all the physical dimensions. + # Finally, the input will be transformed back to its original shape in self.post_process + input_meta_data = self.node.args[0]._meta_data + input_logical_shape = input_meta_data.view(-1).shape + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data, + logical_shape=input_logical_shape) + + # check if the other operand is a parameter + if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + physical_other_operand = OperationData(name=str(self.node.args[1]), + type=data_type, + data=self.node.args[1]._meta_data) + + # Same as input, in F.embedding operation, all the dimensions of output will be treated as + # (batch dimension, embedding dimension), and then the sharding spec will be generated based + # on the logical 2D tensor. + # After that, the logical sharding info of batch dimension will be enumerated among all the physical dimensions. + # Finally, the output will be transformed back to its original shape in self.post_process + output_meta_data = self.node._meta_data + output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape + physical_output = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=self.node._meta_data, + logical_shape=output_logical_shape, + ) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + return mapping + + def post_process(self, strategy: ShardingStrategy): + """ + Convert the sharding spec from the logical shape to the physical shape. + """ + # create multiple sharding strategies for the inputs + # as input can be multi-dimensinal and the partition dim is only 2D, + # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy, + input_name=str( + self.node.args[0]), + output_name=str(self.node)) + return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..53addb873d1d1a014352058f8ec127f6bf7c4d91 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py @@ -0,0 +1,34 @@ +from typing import Dict, List + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .strategy import GetattrGenerator, StrategyGenerator + +__all__ = ['GetattrHandler'] + + +class GetattrHandler(NodeHandler): + """ + A GetattrHandler which deals with the sharding strategies for Getattr Node. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(GetattrGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + + # There are only two possible types for get_attr node: + # 1. torch.Tensor(torch.nn.Parameters or torch.nn.Buffers) + # 2. torch.nn.Module + # temporarily, we just support first case in Tracer, so we don't have to worry about + # issue related to the node._meta_data type. + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..3466e9dd9940e748da4bc8abb3488aacf98cd8ff --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py @@ -0,0 +1,41 @@ +import operator +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator + +__all__ = ['GetItemHandler'] + + +@operator_registry.register(operator.getitem) +class GetItemHandler(NodeHandler): + """ + A GetItemHandler which deals with the sharding strategies for operator.getitem. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + if isinstance(op_data_mapping["input"].data, torch.Tensor): + generators.append(TensorStrategyGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + else: + generators.append(TensorTupleStrategyGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_other_operand = OperationData(name="index", type=OperationDataType.ARG, data=self.node.args[1]) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "index": physical_other_operand, "output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..452381169b74d093188e0f8d7775037f8bf5019c --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py @@ -0,0 +1,44 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import MetaInfoModuleHandler, ModuleHandler +from .registry import operator_registry +from .strategy import LayerNormGenerator, StrategyGenerator + +__all__ = ['LayerNormModuleHandler'] + + +@operator_registry.register(torch.nn.LayerNorm) +class LayerNormModuleHandler(MetaInfoModuleHandler): + """ + A LayerNormModuleHandler which deals with the sharding strategies for nn.LayerNorm module. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(LayerNormGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_other_operand = OperationData(name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters['weight'], + logical_shape=self.named_parameters['weight'].shape) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if self.named_parameters['bias'] is not None: + physical_bias_operand = OperationData(name="bias", + type=OperationDataType.PARAM, + data=self.named_parameters['bias']) + mapping['bias'] = physical_bias_operand + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..59091dab519f4e4458461b84e444b9a034f4df98 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py @@ -0,0 +1,275 @@ +from typing import Dict, List, Union + +import torch +import torch.nn.functional as F + +from colossalai.auto_parallel.tensor_shard.utils import ( + check_sharding_spec_validity, + transpose_partition_dim, + update_partition_dim, +) +from colossalai.logging import get_dist_logger +from colossalai.tensor.sharding_spec import ShardingNotDivisibleError + +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector +from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler +from .registry import operator_registry +from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator + +__all__ = ['LinearModuleHandler', 'LinearFunctionHandler'] + + +def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy, + weight_name: str) -> ShardingStrategy: + """ + This function is a helper function used by both module node handler and function node handler. This function will + convert the sharding spec for the transposed weight to the correct partititon spec. + + Args: + strategy (ShardingStrategy): the strategy generated by the strategy generator. + weight_name (str): the name of the OperationData object for the weight. + """ + # switch the dimensions of the transposed weight + sharding_spec = strategy.get_sharding_spec_by_name(weight_name) + op_data = strategy.get_op_data_by_name(weight_name) + assert op_data.logical_shape[0] == op_data.data.shape[1] and \ + op_data.logical_shape[1] == op_data.data.shape[0], \ + "Expected the logical shape of the linear operator's weight is equal to transposed physical shape" + dim_size = len(op_data.logical_shape) + transpose_partition_dim(sharding_spec, 0, dim_size - 1) + return strategy + + +def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str, + output_name: str) -> List[ShardingStrategy]: + """ + This function converts the logical sharding spec to the physical sharding spec for both the input and output of the linear operation. The input and output + should have the same sharding spec. + + Args: + strategy (ShardingStrategy): the logical strategy generated by the strategy generator. + input_name (str): the name of the OperationData object for the input. + output_name (str): the name of the OperationData object for the output. + + + """ + # the result will be a list of strategies + sharding_strategies = [] + + # get operation data + input_op_data = strategy.get_op_data_by_name(input_name) + output_op_data = strategy.get_op_data_by_name(output_name) + input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name) + + # recover the last logical dimension to physical dimension + last_logical_input_dims = len(input_op_data.logical_shape) - 1 + last_logical_output_dims = len(output_op_data.logical_shape) - 1 + last_physical_input_dims = input_op_data.data.dim() - 1 + last_physical_output_dims = output_op_data.data.dim() - 1 + + if last_logical_input_dims in input_sharding_spec.dim_partition_dict: + input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims} + else: + input_last_dim_mapping = {} + + if last_logical_output_dims in output_sharding_spec.dim_partition_dict: + output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims} + else: + output_last_dim_mapping = {} + + # get logger for debug message + logger = get_dist_logger() + + # for the input of the linear operation, it can be multi-dimensional. The sharding spec generated is only + # 2D, where the first dimension is non-matrix dimension and the last dimension is the matrix dimension. + # the logical non-matrix dimension can belong to the 0th to (N-1)th dimension of the physical input shape. + # Thus, we enumerate to get all possible cases. + if 0 in input_sharding_spec.dim_partition_dict: + # if 0 is in the dim_partition_dict, it means that the + # the generated sharding strategy does shard the non-matrix dimension, + # in this case, we need to do enumeration + num_input_dims = input_op_data.data.dim() + for i in range(num_input_dims - 1): + strategy_copy = strategy.clone() + input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) + try: + # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding + input_dim_mapping = {0: i} + input_dim_mapping.update(input_last_dim_mapping) + + update_partition_dim(sharding_spec=input_sharding_spec, + dim_mapping=input_dim_mapping, + physical_shape=input_op_data.data.shape, + inplace=True) + output_dim_mapping = {0: i} + output_dim_mapping.update(output_last_dim_mapping) + + update_partition_dim(sharding_spec=output_sharding_spec, + dim_mapping=output_dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True) + strategy_copy.name = f'{strategy.name}_{i}' + sharding_strategies.append(strategy_copy) + except ShardingNotDivisibleError as e: + logger.debug( + f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}' + ) + else: + # the generated sharding strategy does not shard the non-matrix dimension, + # in this case, we don't need to do enumeration + # but instead, we still need to convert the logical shape to physical shape + strategy_copy = strategy.clone() + input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) + + # after updating, the logical shape will be replaced by the physical shape + input_dim_mapping = {} + input_dim_mapping.update(input_last_dim_mapping) + update_partition_dim(sharding_spec=input_sharding_spec, + dim_mapping=input_dim_mapping, + physical_shape=input_op_data.data.shape, + inplace=True) + + output_dim_mapping = {} + output_dim_mapping.update(output_last_dim_mapping) + update_partition_dim(sharding_spec=output_sharding_spec, + dim_mapping=output_dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True) + sharding_strategies.append(strategy_copy) + return sharding_strategies + + +@operator_registry.register(torch.nn.Linear) +class LinearModuleHandler(MetaInfoModuleHandler): + """ + A LinearModuleHandler which deals with the sharding strategies for nn.Linear module. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append( + LinearProjectionStrategyGenerator(op_data_mapping, + self.device_mesh, + linear_projection_type='linear', + solver_perference=self.solver_perference)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + input_meta_data = self.node.args[0]._meta_data + input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=input_meta_data, + logical_shape=input_logical_shape) + physical_other_operand = OperationData(name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters['weight'], + logical_shape=self.named_parameters['weight'].shape[::-1]) + output_meta_data = self.node._meta_data + output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape + physical_output = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=output_logical_shape) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if 'bias' in self.named_parameters is not None: + physical_bias_operand = OperationData(name="bias", + type=OperationDataType.PARAM, + data=self.named_parameters['bias']) + mapping['bias'] = physical_bias_operand + return mapping + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + """ + Convert the sharding spec from the logical shape to the physical shape. In this function, two tasks are completed: + 1. the sharding spec is updated for the transposed weight + 2. the input and output sharding specs are updated to physical shape. + """ + # switch the dimensions of the transposed weight + strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight') + + # create multiple sharding strategies for the inputs + # as input can be multi-dimensinal and the partition dim is only 2D, + # we need to map the partition at dim 0 to one of the first few dimensions of the input + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy, + input_name=str(self.node.args[0]), + output_name=str(self.node)) + return strategies + + +@operator_registry.register(F.linear) +class LinearFunctionHandler(MetaInfoNodeHandler): + """ + A LinearFunctionHandler which deals with the sharding strategies for F.Linear. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append( + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + input_meta_data = self.node.args[0]._meta_data + input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data, + logical_shape=input_logical_shape) + + # check if the other operand is a parameter + if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + physical_other_operand = OperationData(name=str(self.node.args[1]), + type=data_type, + data=self.node.args[1]._meta_data, + logical_shape=self.node.args[1]._meta_data.shape[::-1]) + output_meta_data = self.node._meta_data + output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape + physical_output = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=self.node._meta_data, + logical_shape=output_logical_shape, + ) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None: + # check if the other operand is a parameter + if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), + type=data_type, + data=self.node.kwargs["bias"]._meta_data) + mapping['bias'] = physical_bias_operand + + return mapping + + def post_process(self, strategy: ShardingStrategy): + # switch the dimensions of the transposed weight + strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, + weight_name=str(self.node.args[1])) + # create multiple sharding strategies for the inputs + # as input can be multi-dimensinal and the partition dim is only 2D, + # we need to map the partition at dim 0 to one of the first few dimensions of the input + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy, + input_name=str(self.node.args[0]), + output_name=str(self.node)) + return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..f3c9d0cbf8267e2321415ae29887e308a9af35b2 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py @@ -0,0 +1,488 @@ +import operator +from abc import ABC, abstractmethod +from copy import deepcopy +from enum import Enum +from functools import reduce +from typing import Dict, List, Union + +import torch + +from colossalai.auto_parallel.tensor_shard.utils.broadcast import ( + BroadcastType, + get_broadcast_dim_info, + get_broadcast_shape, +) +from colossalai.tensor.sharding_spec import ShardingSpecException + +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from ..utils import recover_sharding_spec_for_broadcast_shape +from .node_handler import MetaInfoNodeHandler, NodeHandler +from .registry import operator_registry +from .strategy import ( + BatchedMatMulStrategyGenerator, + DotProductStrategyGenerator, + LinearProjectionStrategyGenerator, + MatVecStrategyGenerator, + StrategyGenerator, +) + + +class MatMulType(Enum): + """ + The MatMulType is categorized into 4 types based on the reference of torch.matmul + in https://pytorch.org/docs/stable/generated/torch.matmul.html. + + DOT: dot product, both tensors are 1D, these two tensors need to have the same number of elements + MM: matrix-matrix product, both tensors are 2D or the 1st tensor is 1D and the 2nd tensor is 2D + MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D + BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D + """ + DOT = 0 + MM = 1 + MV = 2 + BMM = 3 + + +def get_matmul_type(input_dim: int, other_dim: int): + """ + Determine which type of matmul operation should be executed for the given tensor dimensions. + + Args: + input_dim (int): the number of dimensions for the input tenosr + other_dim (int): the number of dimensions for the other tenosr + """ + if input_dim == 1 and other_dim == 1: + matmul_type = MatMulType.DOT + elif input_dim in [1, 2] and other_dim == 2: + matmul_type = MatMulType.MM + elif input_dim == 2 and other_dim == 1: + matmul_type = MatMulType.MV + elif input_dim >= 1 and other_dim >= 1 and (input_dim > 2 or other_dim > 2): + matmul_type = MatMulType.BMM + else: + raise ValueError( + f"The input and other tensors are of {input_dim} and {other_dim} which cannot used to execute matmul operation" + ) + return matmul_type + + +class BmmTransform(ABC): + """ + BmmTransform is an abstraction of the shape conversion between logical and physical operation data + during the strategy generation. + """ + + @abstractmethod + def apply(self, shape_mapping: Dict[str, List[int]]): + pass + + @abstractmethod + def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): + pass + + +class Padder(BmmTransform): + """ + Add padding to the matrix dimensions for batched matrix multiplication. + """ + + def __init__(self) -> None: + # keep the padding dim, op_name -> padded_dim + self.padded_dim_mapping = {} + + def apply(self, shape_mapping: Dict[str, List[int]]): + mapping_copy = deepcopy(shape_mapping) + input_shape = mapping_copy['input'] + other_shape = mapping_copy['other'] + + if len(input_shape) == 1: + # if the input is a 1D tensor, 1 is prepended to its shape + # and it will be removed afterwards + input_shape.insert(0, 1) + self.padded_dim_mapping['input'] = -2 + self.padded_dim_mapping['output'] = -2 + elif len(other_shape) == 1: + # if the other is a 1D tensor, 1 is appended to its shape + # and it will be removed afterwards + other_shape = other_shape.append(1) + self.padded_dim_mapping['other'] = -1 + self.padded_dim_mapping['output'] = -1 + return mapping_copy + + def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): + input_op_data = op_data_mapping['input'] + other_op_data = op_data_mapping['other'] + + def _remove_padded_dim(key, strategy): + op_data = op_data_mapping[key] + sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) + tensor_shape = list(sharding_spec.entire_shape) + dim_partition_list = [None] * len(tensor_shape) + + # padded dim is a negative number as the padded dim must be a matrix dim + padded_dim = self.padded_dim_mapping[key] + + # compute the new dim partition + for tensor_dim, mesh_dims in sharding_spec.dim_partition_dict.items(): + dim_partition_list[tensor_dim] = mesh_dims + dim_partition_list.pop(padded_dim) + unpadded_dim_partition_list = {k: v for k, v in enumerate(dim_partition_list) if v is not None} + + # compute unpadded tensor shape + tensor_shape.pop(padded_dim) + + assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}' + + # update sharding spec + sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list) + + # enumerate all sharding strategies + strategies = [] + try: + strategy_copy = strategy.clone() + + # only one of input and other will be padded + if 'input' in self.padded_dim_mapping: + _remove_padded_dim('input', strategy_copy) + _remove_padded_dim('output', strategy_copy) + elif 'other' in self.padded_dim_mapping: + _remove_padded_dim('other', strategy_copy) + _remove_padded_dim('output', strategy_copy) + + strategies.append(strategy_copy) + except ShardingSpecException as e: + pass + return strategies + + +class Broadcaster(BmmTransform): + """ + Broadcast the non-matrix dimensions for batched matrix multiplication. + """ + + def __init__(self) -> None: + self.broadcast_dim_info = {} + + def apply(self, shape_mapping: Dict[str, List[int]]): + mapping_copy = shape_mapping.copy() + + # get shapes + input_shape = mapping_copy['input'] + other_shape = mapping_copy['other'] + + # sanity check + assert len(input_shape) > 1 and len(other_shape) > 1 + + # broadcast the batch dim and record + bcast_non_matrix_dims = get_broadcast_shape(input_shape[:-2], other_shape[:-2]) + + # store the broadcast dim info + input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2]) + other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2]) + self.broadcast_dim_info['input'] = input_broadcast_dim_info + self.broadcast_dim_info['other'] = other_broadcast_dim_info + + # create the full logical shape + input_shape = bcast_non_matrix_dims + input_shape[-2:] + other_shape = bcast_non_matrix_dims + other_shape[-2:] + assert len(input_shape) == len(other_shape) + + mapping_copy['input'] = input_shape + mapping_copy['other'] = other_shape + + return mapping_copy + + def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): + # remove sharding on the broadcast dim + def _remove_sharding_on_broadcast_dim(key, strategy): + op_data = op_data_mapping[key] + sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) + tensor_shape = list(sharding_spec.entire_shape) + + for dim_idx, broadcast_type in self.broadcast_dim_info[key].items(): + if broadcast_type == BroadcastType.MULTIPLE: + # if the dim is originally 1 and multiplied during broadcast + # we set its sharding to R + # e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8] + # the dim 0 of [1, 2, 4] is multiplied to 4 + tensor_shape[dim_idx] = 1 + elif broadcast_type == BroadcastType.PADDDING: + # if the dim is padded + # we remove its sharding + tensor_shape[dim_idx] = None + + tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None] + + physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + logical_sharding_spec=sharding_spec, + logical_shape=sharding_spec.entire_shape, + physical_shape=tensor_shape_before_broadcast) + strategy.sharding_specs[op_data] = physical_sharding_spec + + # enumerate all sharding strategies + strategies = [] + try: + strategy_copy = strategy.clone() + _remove_sharding_on_broadcast_dim('input', strategy_copy) + _remove_sharding_on_broadcast_dim('other', strategy_copy) + strategies.append(strategy_copy) + except ShardingSpecException as e: + pass + return strategies + + +class Viewer(BmmTransform): + """ + Change the shape of the tensor from N-D to 3D + """ + + def __init__(self) -> None: + self.batch_dims_before_view = None + + def apply(self, shape_mapping: Dict[str, List[int]]): + mapping_copy = shape_mapping.copy() + self.batch_dims_before_view = list(mapping_copy['input'][:-2]) + + # get shapes + input_shape = shape_mapping['input'] + other_shape = shape_mapping['other'] + + # view to 3d tensor + assert len(input_shape) >= 3 and len(other_shape) >= 3 + input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:] + other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:] + output_shape = input_shape[:2] + other_shape[2:] + mapping_copy['input'] = input_shape + mapping_copy['other'] = other_shape + mapping_copy['output'] = output_shape + return mapping_copy + + def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): + # get operation data + def _update_sharding_spec(key, strategy, physical_batch_dim): + """ + Map the logical batch dim to the physical batch dim + """ + op_data = op_data_mapping[key] + sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) + dim_partition_dict = sharding_spec.dim_partition_dict + entire_shape = sharding_spec.entire_shape + + # upddate the dimension index for the matrix dimensions + if 2 in dim_partition_dict: + dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2) + if 1 in dim_partition_dict: + dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1) + + # map the logical batch dim to phyiscal batch dim + if 0 in dim_partition_dict: + batch_dim_shard = dim_partition_dict.pop(0) + dim_partition_dict[physical_batch_dim] = batch_dim_shard + + # the new shape will be the batch dims + the last 2 matrix dims + shape_before_view = self.batch_dims_before_view + list(entire_shape[-2:]) + sharding_spec.__init__(sharding_spec.device_mesh, shape_before_view, dim_partition_dict) + + num_batch_dim_before_view = len(self.batch_dims_before_view) + + # enumerate all sharding strategies + strategies = [] + for i in range(num_batch_dim_before_view): + # create a new strategy + strategy_copy = strategy.clone() + try: + _update_sharding_spec('input', strategy_copy, i) + _update_sharding_spec('other', strategy_copy, i) + _update_sharding_spec('output', strategy_copy, i) + strategies.append(strategy_copy) + except ShardingSpecException as e: + continue + return strategies + + +def _get_bmm_logical_shape(input_shape, other_shape, transforms): + """ + Compute the logical shapes for BMM operation. BMM has a general representation + [b, i, k] = [b, i, j] x [b, j, k] + + The dimension b is called non-matrix (batch) dimension and the remaining dimensions are called matrix dimensions + The logical shape for the bmm operands will undergo three stages + 1. append/prepend the 1 to the 1D tensor if there is any + 2. broadcast the non-matrix dimensions + 3. reshape to 3 dimensions + + """ + shape_mapping = {'input': input_shape, 'other': other_shape} + + for transform in transforms: + shape_mapping = transform.apply(shape_mapping) + + input_shape = shape_mapping.get('input', None) + other_shape = shape_mapping.get('other', None) + output_shape = shape_mapping.get('output', None) + + return input_shape, other_shape, output_shape + + +@operator_registry.register(torch.matmul) +@operator_registry.register(torch.Tensor.matmul) +class MatMulHandler(MetaInfoNodeHandler): + """ + The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation. + According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on + the operands. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # check which type of operation this matmul will call + self.input_meta_data = self.node.args[0]._meta_data + self.other_meta_data = self.node.args[1]._meta_data + self.output_meta_data = self.node._meta_data + + input_dim = self.input_meta_data.dim() + other_dim = self.other_meta_data.dim() + self.matmul_type = get_matmul_type(input_dim, other_dim) + + if self.matmul_type == MatMulType.BMM: + # bmm operation can possibly involve padding, broadcasting and view + # these transforms will be used to create logical shape and + # recover physical sharding spec + self.transforms = [Padder(), Broadcaster(), Viewer()] + else: + self.transforms = None + + def get_strategy_generator(self) -> List[StrategyGenerator]: + generators = [] + op_data_mapping = self.get_operation_data_mapping() + if self.matmul_type == MatMulType.BMM: + generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)) + elif self.matmul_type == MatMulType.DOT: + generators.append(DotProductStrategyGenerator(op_data_mapping, self.device_mesh)) + elif self.matmul_type == MatMulType.MV: + generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh)) + elif self.matmul_type == MatMulType.MM: + generators.append( + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + logical_shape_func = { + MatMulType.DOT: self._get_logical_shape_for_dot, + MatMulType.MM: self._get_logical_shape_for_mm, + MatMulType.MV: self._get_logical_shape_for_mv, + MatMulType.BMM: self._get_logical_shape_for_bmm + } + logical_shapes = logical_shape_func[self.matmul_type]() + op_data_mapping = self._get_op_data_mapping(*logical_shapes) + return op_data_mapping + + def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_logical_shape): + # convert list to torch.Size + if input_logical_shape: + input_logical_shape = torch.Size(input_logical_shape) + + if other_logical_shape: + other_logical_shape = torch.Size(other_logical_shape) + + if output_logical_shape: + output_logical_shape = torch.Size(output_logical_shape) + + # create op data + input_op_data = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.input_meta_data, + logical_shape=input_logical_shape) + other_op_data = OperationData(name=str(self.node.args[1]), + type=OperationDataType.ARG, + data=self.other_meta_data, + logical_shape=other_logical_shape) + output_op_data = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=self.output_meta_data, + logical_shape=output_logical_shape) + + mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} + return mapping + + def _get_logical_shape_for_dot(self): + """ + The operands for the dot operation have the same logical shape as the physical shape + """ + return None, None, None + + def _get_logical_shape_for_mm(self): + """ + We need to handle the input tensor for a matrix-matrix multiplcation as the input + tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape + (e.g. [4] -> [1, 4]). + """ + if self.input_meta_data.dim() == 1: + input_logical_shape = [1] + list(self.input_meta_data.shape) + input_logical_shape = torch.Size(input_logical_shape) + else: + input_logical_shape = None + return input_logical_shape, None, None + + def _get_logical_shape_for_mv(self): + """ + No broadcasting or dim insertion occurs for matrix-vector operation. + """ + return None, None, None + + def _get_logical_shape_for_bmm(self): + input_physical_shape = list(self.input_meta_data.shape) + other_physical_shape = list(self.other_meta_data.shape) + return _get_bmm_logical_shape(input_physical_shape, other_physical_shape, self.transforms) + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + if self.matmul_type in [MatMulType.DOT, MatMulType.MV]: + return strategy + elif self.matmul_type == MatMulType.MM: + if self.input_meta_data.dim() == 1: + # if a 1 is prepended to the input shape (this occurs when input is a 1D tensor) + # we need to remove that dim + input_sharding_spec = strategy.get_sharding_spec_by_name(str(self.node.args[0])) + input_physical_shape = self.node.args[0]._meta_data.shape + dim_partition_dict = input_sharding_spec.dim_partition_dict + + # remove the partitioning in the dim 0 + if 0 in dim_partition_dict: + dim_partition_dict.pop(0, None) + + # move the partitioning in dim 1 to dim 0 + if -1 in dim_partition_dict: + shard = dim_partition_dict.pop(-1) + dim_partition_dict[0] = shard + if 1 in dim_partition_dict: + shard = dim_partition_dict.pop(1) + dim_partition_dict[0] = shard + + # re-init the sharding spec + input_sharding_spec.__init__(input_sharding_spec.device_mesh, + entire_shape=input_physical_shape, + dim_partition_dict=dim_partition_dict) + return strategy + else: + return strategy + elif self.matmul_type == MatMulType.BMM: + op_data_mapping = self.get_operation_data_mapping() + + strategies = [strategy] + # recover the physical sharding spec + for transform in self.transforms[::-1]: + recovered_stragies = [] + for strategy_ in strategies: + output = transform.recover(op_data_mapping, strategy_) + if isinstance(output, ShardingStrategy): + recovered_stragies.append(output) + elif isinstance(output, (list, tuple)): + recovered_stragies.extend(output) + else: + raise TypeError( + f"Found unexpected output type {type(output)} from the recover method of BmmTransform") + strategies = recovered_stragies + for index, strategies in enumerate(strategies): + strategies.name = f"{strategies.name}_{index}" + return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ab391ebfaf80960ef49a4e9c4761c76f82567d25 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -0,0 +1,338 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple, Union + +import torch +from torch.fx.node import Node + +from colossalai.auto_parallel.meta_profiler.shard_metainfo import ShardMetaInfo, meta_register +from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingSpec, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import check_sharding_spec_validity +from colossalai.device.device_mesh import DeviceMesh +from colossalai.logging import get_dist_logger +from colossalai.tensor.shape_consistency import ShapeConsistencyManager + +from .strategy import StrategyGenerator + + +class NodeHandler(ABC): + ''' + The NodeHandler is an abstract class used to generate every possible strategies for an operator node. + + Args: + node (Node): the input node in node argument list. + device_mesh (DeviceMesh): A logical view of a physical mesh. + strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector. + ''' + + def __init__(self, + node: Node, + device_mesh: DeviceMesh, + strategies_vector: StrategiesVector, + shard_option: ShardOption = ShardOption.STANDARD, + solver_perference: SolverPerference = SolverPerference.STANDARD) -> None: + self.node = node + self.predecessor_node = list(node._input_nodes.keys()) + self.successor_node = list(node.users.keys()) + self.device_mesh = device_mesh + self.strategies_vector = strategies_vector + self.shard_option = shard_option + self.solver_perference = solver_perference + + def update_resharding_cost(self, strategy: ShardingStrategy) -> None: + """ + Compute the resharding costs and save the costs in the ShardingStrategy object. + """ + # TODO: test this function when other handlers are ready + resharding_costs = {} + shape_consistency_manager = ShapeConsistencyManager() + + for node in self.predecessor_node: + node_name = str(node) + # get the current sharding spec generated by this node handler + + # we will not compute the resharding costs for the node not counted in the strategy. + # And the node with tuple or list output need to be handled below. + node_in_strategy = [op_data.name for op_data in strategy.sharding_specs.keys()] + if str(node) not in node_in_strategy: + continue + + op_data = strategy.get_op_data_by_name(node_name) + current_sharding_spec = strategy.sharding_specs[op_data] + # get the sharding specs for this node generated + # in its own node handler + assert hasattr(node, 'strategies_vector'), \ + f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.' + prev_strategy_vector = node.strategies_vector + prev_sharding_specs = [ + prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector + ] + + # create data structrure to store costs + if node not in resharding_costs: + resharding_costs[node] = [] + + def _compute_resharding_cost( + prev_sharding_spec: Union[ShardingSpec, + List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec, + List[ShardingSpec]], + data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem: + """ + This is a helper function to compute the resharding cost for a specific strategy of a node. + """ + if prev_sharding_spec is None: + return TrainCycleItem(fwd=0, bwd=0, total=0) + elif isinstance(prev_sharding_spec, ShardingSpec): + if isinstance(data, torch.Tensor): + dtype = data.dtype + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + _, _, consistency_cost = shape_consistency_manager.shape_consistency( + prev_sharding_spec, current_sharding_spec) + + resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes, + bwd=consistency_cost["backward"] * size_per_elem_bytes, + total=consistency_cost["total"] * size_per_elem_bytes) + return resharding_cost + else: + # This raise is used to check if we have missed any type of data. + # It could be merged into Parameter branch, which means we won't handle + # non-tensor arguments. + raise ValueError(f'Unsupported data type {type(data)}') + else: + assert isinstance(prev_sharding_spec, (tuple, list)), \ + f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \ + or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}' + + fwd_cost = 0 + bwd_cost = 0 + total_cost = 0 + for index, (prev_sharding_spec_item, + current_sharding_spec_item) in enumerate(zip(prev_sharding_spec, + current_sharding_spec)): + item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item, + data[index]) + fwd_cost += item_cost.fwd + bwd_cost += item_cost.bwd + total_cost += item_cost.total + resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=total_cost) + return resharding_cost + + # for each sharding spec generated by the predecessor's node handler + # compute the resharding cost to switch to the sharding spec generated + # by the current node handler + for prev_sharding_spec in prev_sharding_specs: + resharding_cost = _compute_resharding_cost(prev_sharding_spec, current_sharding_spec, op_data.data) + resharding_costs[node].append(resharding_cost) + strategy.resharding_costs = resharding_costs + return strategy + + def get_target_function(self) -> callable: + """ + This function is used to get the target function for the node handler. + The target function is used to analyze the costs of strategies. + """ + if self.node.op in ('placeholder', 'get_attr', 'output'): + return None + + if self.node.op == 'call_module': + target = self.node.graph.owning_module.get_submodule(self.node.target) + elif self.node.op == 'call_function': + target = self.node.target + elif self.node.op == 'call_method': + target = getattr(self.node.args[0]._meta_data.__class__, self.node.target) + else: + raise ValueError(f'Unsupported node type: {self.node.op}') + + return target + + def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector: + """ + Register different sharding strategies for the current node. + """ + strategy_generators = self.get_strategy_generator() + for generator in strategy_generators: + strategies = generator.generate() + + # postprocess a strategy + # postprocess can produce one strategy or multiple strategies + post_processed_strategies_map = map(self.post_process, strategies) + post_processed_strategies = [] + + for strategy in post_processed_strategies_map: + if isinstance(strategy, (list, tuple)): + post_processed_strategies.extend(strategy) + else: + post_processed_strategies.append(strategy) + + # compute the resharding costs based on the previous node + # strategies if specified + if compute_resharding_cost: + updated_strategies = map(self.update_resharding_cost, post_processed_strategies) + post_processed_strategies = list(updated_strategies) + + self.strategies_vector.extend(post_processed_strategies) + + # validating the correctness of the sharding strategy + for strategy in self.strategies_vector: + for op_data, sharding_spec in strategy.sharding_specs.items(): + if op_data.data is not None and isinstance(op_data.data, torch.Tensor): + check_sharding_spec_validity(sharding_spec, op_data.data) + + remove_strategy_list = [] + for strategy in self.strategies_vector: + shard_axis_list = [] + last_axis = len(self.device_mesh.mesh_shape) - 1 + for op_data, sharding_spec in strategy.sharding_specs.items(): + if op_data.data is not None and isinstance(op_data.data, torch.Tensor): + for dim, shard_axes in sharding_spec.dim_partition_dict.items(): + for shard_axis in shard_axes: + if shard_axis not in shard_axis_list: + shard_axis_list.append(shard_axis) + + shard_level = len(shard_axis_list) + using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list + if self.shard_option == ShardOption.SHARD and shard_level == 0: + remove_strategy_list.append(strategy) + if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1: + remove_strategy_list.append(strategy) + if self.shard_option == ShardOption.SHARD_LAST_AXIS: + if shard_level != 1 or using_last_axis == False: + remove_strategy_list.append(strategy) + + for strategy in remove_strategy_list: + self.strategies_vector.remove(strategy) + + return self.strategies_vector + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + # tranform the strategy generated + # e.g. to process the sharding strategy for the transposed weights + return strategy + + @abstractmethod + def get_strategy_generator(self) -> List[StrategyGenerator]: + """ + Define which generators should be used by this NodeHandler object. + """ + pass + + @abstractmethod + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + """ + Returns the mapping between the logical operation data to its physical data. + A logical operation data is a data associated with an operation, which can be input and output. It is + defined by the strategy generator, for example, a matrix multiplication operation has two operands "input" + and "other" and one result "output". For a nn.Linear module, the physical operand for "input" is + the module input, the physical operand for "other" is the module weight, and the physical result for "output" + is the module output. + Note that the operand name is specified by the StrategyGenerator object. + + For example: + + # for a linear layer + mapping = { + "input": Operand(name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data), + "other": Operand(name="weight", type=OperationDataType.PARAM, data=self.named_parameters['weight']), + "bias": Operand(name="bias", type=OperationDataType.PARAM, data=self.named_parameters['bias']), + "output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data), + } + """ + pass + + +class MetaInfoNodeHandler(NodeHandler): + """ + This is a base class to handle the nodes patched in the meta profiler. + + Note: this class will be integrated into the NodeHandler class in the future, after + all the functions are patched. + """ + + def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector: + """ + This method is inherited from NodeHandler. It will register the strategies first, + and rewrite the memory_cost and compute_cost of the strategy using the ShardMetaInfo class. + """ + super().register_strategy(compute_resharding_cost=compute_resharding_cost) + target = self.get_target_function() + # Currently we haven't patched all the torch functions and modules, so if the target + # is not patched, we will use the default cost model to compute the cost. + # TODO: patch all torch functions and modules to make it clean + if meta_register.has(target.__class__) or meta_register.has(target): + strategies_info = [] + for strategy in self.strategies_vector: + metainfo = ShardMetaInfo(strategy, target) + strategy.compute_cost = metainfo.compute_cost + strategy.memory_cost = metainfo.memory_cost + strategies_info.append(metainfo) + + # attach metainfos to the handler + setattr(self, "strategies_info", strategies_info) + + else: + logger = get_dist_logger() + logger.warning(f'The target function {target} is not patched yet, ') + + return self.strategies_vector + + +class ModuleHandler(NodeHandler): + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # set attributes to access module parameters for convenience + assert self.node.graph.owning_module is not None, \ + f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.' + module = self.node.graph.owning_module.get_submodule(self.node.target) + named_parameters = list(module.named_parameters(recurse=False)) + named_buffers = list(module.named_buffers(recurse=False)) + # convert named parameters from list to dict + named_parameters = {k: v for k, v in named_parameters} + named_buffers = {k: v for k, v in named_buffers} + self.module = module + self.named_parameters = named_parameters + self.named_buffers = named_buffers + + +class MetaInfoModuleHandler(ModuleHandler): + """ + This is a base class to handle the module patched in the meta profiler. + + Note: this class will be integrated into the ModuleHandler class in the future, after + all the modules are patched. + """ + + def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector: + """ + This method is inherited from NodeHandler. It will register the strategies first, + and rewrite the memory_cost and compute_cost of the strategy using the ShardMetaInfo class. + """ + super().register_strategy(compute_resharding_cost=compute_resharding_cost) + target = self.get_target_function() + # Currently we haven't patched all the torch functions and modules, so if the target + # is not patched, we will use the default cost model to compute the cost. + # TODO: patch all torch functions and modules to make it clean + if meta_register.has(target.__class__) or meta_register.has(target): + strategies_info = [] + for strategy in self.strategies_vector: + metainfo = ShardMetaInfo(strategy, target) + strategy.compute_cost = metainfo.compute_cost + strategy.memory_cost = metainfo.memory_cost + strategies_info.append(metainfo) + + # attach metainfos to the handler + setattr(self, "strategies_info", strategies_info) + + else: + logger = get_dist_logger() + logger.warning(f'The target function {target} is not patched yet') + + return self.strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..4e71ccba95a7e6457309a455986400dc49893d18 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py @@ -0,0 +1,41 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import MetaInfoModuleHandler, ModuleHandler +from .registry import operator_registry +from .strategy import NormalPoolStrategyGenerator, StrategyGenerator + +__all__ = ['NormPoolingHandler'] + + +@operator_registry.register(torch.nn.MaxPool1d) +@operator_registry.register(torch.nn.MaxPool2d) +@operator_registry.register(torch.nn.MaxPool1d) +@operator_registry.register(torch.nn.AvgPool1d) +@operator_registry.register(torch.nn.AvgPool2d) +@operator_registry.register(torch.nn.AvgPool3d) +class NormPoolingHandler(MetaInfoModuleHandler): + """ + A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd module. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(NormalPoolStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "other": physical_weight_operand, "output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ed120a8c3d6df9b5d10f44f2b86be1c3cf283c10 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py @@ -0,0 +1,52 @@ +from typing import Dict, List + +import torch + +from colossalai.device.device_mesh import DeviceMesh + +from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector +from .node_handler import NodeHandler +from .strategy import OutputGenerator, StrategyGenerator + +__all__ = ['OutputHandler'] + + +class OutputHandler(NodeHandler): + """ + A OutputHandler which deals with the sharding strategies for Output Node. + """ + + def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, + output_option: str) -> None: + super().__init__(node, device_mesh, strategies_vector) + self.output_option = output_option + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node, self.output_option)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + mapping = {} + output_meta_data = [] + for index, input_node in enumerate(self.predecessor_node): + input_meta_data = input_node._meta_data + physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data) + name_key = f'input_{index}' + mapping[name_key] = physical_inputs + output_meta_data.append(input_meta_data) + + assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.' + if len(output_meta_data) == 1: + output_meta_data = output_meta_data[0] + else: + output_meta_data = tuple(output_meta_data) + + self.node._meta_data = output_meta_data + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping["output"] = physical_output + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..91e4a5105a08ff7d28cebba41f4962daa951259c --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py @@ -0,0 +1,75 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import PermuteGenerator, StrategyGenerator + +__all__ = ['PermuteHandler'] + + +@operator_registry.register(torch.Tensor.permute) +@operator_registry.register(torch.permute) +class PermuteHandler(NodeHandler): + """ + A PermuteHandler which deals with the sharding strategies for torch.permute or torch.transpose. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(PermuteGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + permute_dims = [] + if self.node.op == 'call_method': + # torch.Tensor.permute (input, *dims) + for arg in self.node.args: + if isinstance(arg, torch.fx.Node): + if isinstance(arg._meta_data, int): + permute_dims.append(arg._meta_data) + else: + assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.' + permute_dims.append(arg) + else: + # torch.permute (input, dims) + for arg in self.node.args: + if isinstance(arg, torch.fx.Node): + if isinstance(arg._meta_data, (tuple, list)): + permute_dims.extend(arg._meta_data) + else: + assert isinstance( + arg, + (tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].' + permute_dims.extend(arg) + + num_dims = self.node._meta_data.dim() + for i in range(num_dims): + # recover negative value to positive + if permute_dims[i] < 0: + permute_dims[i] += num_dims + + physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims)) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "permute_dims": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f40fc935a404dd8625c82fbb4dc7511c9fc839 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py @@ -0,0 +1,38 @@ +from typing import Dict, List + +from torch.fx.node import Node + +from colossalai.device.device_mesh import DeviceMesh + +from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector +from .node_handler import NodeHandler +from .strategy import PlaceholderGenerator, StrategyGenerator + +__all__ = ['PlaceholderHandler'] + + +class PlaceholderHandler(NodeHandler): + """ + A PlaceholderHandler which deals with the sharding strategies for Placeholder Node. + """ + + def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, + placeholder_option: str) -> None: + super().__init__(node, device_mesh, strategies_vector) + self.placeholder_option = placeholder_option + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append( + PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..8e06cec4f463a8600b2abe1a7f6713ec2ffb2931 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py @@ -0,0 +1,30 @@ +class Registry: + # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here + + def __init__(self, name): + self.name = name + self.store = {} + + def register(self, source): + + def wrapper(func): + if isinstance(source, (list, tuple)): + # support register a list of items for this func + for element in source: + self.store[element] = func + else: + self.store[source] = func + return func + + return wrapper + + def get(self, source): + assert source in self.store, f'{source} not found in the {self.name} registry' + target = self.store[source] + return target + + def has(self, source): + return source in self.store + + +operator_registry = Registry('operator') diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..743a1f90eaafa869b3a62882648cbde53f9e3166 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py @@ -0,0 +1,55 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import SoftmaxGenerator, StrategyGenerator + +__all__ = ['SoftmaxHandler'] + + +@operator_registry.register(torch.nn.Softmax) +@operator_registry.register(torch.nn.functional.softmax) +class SoftmaxHandler(NodeHandler): + """ + A SoftmaxHandler which deals with the sharding strategies for + torch.nn.Softmax or torch.nn.functional.softmax. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(SoftmaxGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + softmax_dim = self.node.kwargs['dim'] + + num_dims = self.node.args[0]._meta_data.dim() + # recover negative value to positive + if softmax_dim < 0: + softmax_dim += num_dims + + physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "softmax_dim": physical_dim_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..653d158b7c36ee1ff27791add2edad4093ce8675 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py @@ -0,0 +1,62 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import SplitGenerator, StrategyGenerator + +__all__ = ['SplitHandler'] + + +@operator_registry.register(torch.Tensor.split) +@operator_registry.register(torch.split) +class SplitHandler(NodeHandler): + """ + A SplitHandler which deals with the sharding strategies for torch.permute or torch.split. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(SplitGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + split_size = self.node.args[1] + if len(self.node.args) == 3: + # (input, split_size, split_dim) + split_dim = self.node.args[2] + else: + if self.node.kwargs: + split_dim = self.node.kwargs['dim'] + else: + split_dim = 0 + + num_dims = self.node.args[0]._meta_data.dim() + # recover negative value to positive + if split_dim < 0: + split_dim += num_dims + + split_info = (split_size, split_dim) + physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "split_info": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db1f31521c86ef1842e93d9bbdbc58953e11934d --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py @@ -0,0 +1,39 @@ +from .batch_norm_generator import BatchNormStrategyGenerator +from .binary_elementwise_generator import BinaryElementwiseStrategyGenerator +from .conv_strategy_generator import ConvStrategyGenerator +from .embedding_generator import EmbeddingStrategyGenerator +from .getattr_generator import GetattrGenerator +from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator +from .layer_norm_generator import LayerNormGenerator +from .matmul_strategy_generator import ( + BatchedMatMulStrategyGenerator, + DotProductStrategyGenerator, + LinearProjectionStrategyGenerator, + MatVecStrategyGenerator, +) +from .normal_pooling_generator import NormalPoolStrategyGenerator +from .output_generator import OutputGenerator +from .placeholder_generator import PlaceholderGenerator +from .reshape_generator import ( + DefaultReshapeGenerator, + PermuteGenerator, + SplitGenerator, + TransposeGenerator, + ViewGenerator, +) +from .softmax_generator import SoftmaxGenerator +from .strategy_generator import StrategyGenerator +from .sum_generator import SumGenerator +from .tensor_constructor_generator import TensorConstructorGenerator +from .unary_elementwise_generator import UnaryElementwiseGenerator +from .where_generator import WhereGenerator + +__all__ = [ + 'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator', + 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator', + 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator', + 'LayerNormGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', 'NormalPoolStrategyGenerator', + 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', 'TensorConstructorGenerator', + 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator', 'ViewGenerator', 'PermuteGenerator', + 'TransposeGenerator', 'SplitGenerator', 'DefaultReshapeGenerator' +] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..1f3812429fc274064163f5859d0fedb04f8115fb --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py @@ -0,0 +1,350 @@ +import copy +import operator +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import StrategyGenerator + +__all__ = ['BatchNormStrategyGenerator'] + + +class BatchNormStrategyGenerator(StrategyGenerator): + """ + A StrategyGenerator which deals with the sharding strategies of batch normalization. + + To keep the math consistency, there are two way to do BatchNorm if the input + shards on batch dimension: + 1. We gather the input partitions through batch dimension, then do the normal BatchNorm. + 2. We do the SyncBatchNorm on the each input partition seperately, the SyncBN op will help + us to keep the computing correctness. + In this generator, both methods will be considered. + """ + + def validate(self) -> bool: + ''' + In sanity check, we need make sure the input data having correct dimension size. + For BatchNorm1d, the dim of input data should be 3([N, C, L]). + For BatchNorm2d, the dim of input data should be 4([N, C, H, W]). + For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]). + ''' + input_op_data = self.op_data['input'] + assert input_op_data.data.dim() in ( + 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + + def update_compute_cost(self, strategy: ShardingStrategy): + ''' + Compute the computation cost per device with this specific strategy. + + Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + ''' + # TODO: a constant coefficient need to be added. + # 1D: (L) * N * Cin + # 2D: (H * W) * N * Cin + # 3D: (H * W * D) * N * Cin + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + if self.has_bias: + # bias add is an element wise operation, so the cost is equal to product of output shape. + bias_compute_cost = reduce(operator.mul, sharded_output_shape) + input_product = reduce(operator.mul, sharded_input_shape, 1) + forward_compute_cost = input_product + backward_activation_compute_cost = input_product + backward_weight_compute_cost = input_product + backward_compute_cost = backward_weight_compute_cost + backward_activation_compute_cost + if self.has_bias: + forward_compute_cost += bias_compute_cost + backward_compute_cost += bias_compute_cost + total_compute_cost = forward_compute_cost + backward_compute_cost + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'other': self._compute_size_in_bytes(strategy, "other"), + 'output': self._compute_size_in_bytes(strategy, "output"), + 'running_mean': self._compute_size_in_bytes(strategy, "running_mean"), + 'running_var': self._compute_size_in_bytes(strategy, "running_var"), + } + + if self.has_bias: + bias_size = self._compute_size_in_bytes(strategy, "bias") + forward_size_mapping['bias'] = bias_size + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + other + bias + output + fwd_activation_cost = sum( + [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_buffer_cost = sum([v for k, v in forward_size_mapping.items() if self.is_buffer(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost, buffer=fwd_buffer_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + other_grad + bias_grad + bwd_activation_cost = sum( + [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost, + buffer=fwd_buffer_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def split_input_channel(self, mesh_dim_0): + name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' + dim_partition_dict_mapping = { + "input": { + 1: [mesh_dim_0] + }, + "other": { + 0: [mesh_dim_0] + }, + "output": { + 1: [mesh_dim_0] + }, + "running_mean": { + 0: [mesh_dim_0] + }, + "running_var": { + 0: [mesh_dim_0] + }, + "num_batches_tracked": {}, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' + dim_partition_dict_mapping = { + "input": { + 1: [mesh_dim_0, mesh_dim_1] + }, + "other": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "output": { + 1: [mesh_dim_0, mesh_dim_1] + }, + "running_mean": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "running_var": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "num_batches_tracked": {}, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def non_split(self): + name = f'RR = RR x R' + dim_partition_dict_mapping = { + "input": {}, + "other": {}, + "output": {}, + "running_mean": {}, + "running_var": {}, + "num_batches_tracked": {}, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_batch(self, mesh_dim_0): + name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0] + }, + "other": {}, + "output": { + 0: [mesh_dim_0] + }, + "running_mean": {}, + "running_var": {}, + "num_batches_tracked": {}, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + # For SyncBN case, we don't need to do communication for weight and bias. + # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation + # to SyncBN operation instead of inserting a communication node. + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.IMPLICIT) + + # TODO: Temporary solution has no communication cost, + # above action should be added after the SyncBN replace pass completed. + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "other": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "running_mean": {}, + "running_var": {}, + "num_batches_tracked": {}, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + # For SyncBN case, we don't need to do communication for gradients of weight and bias. + # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation + # to SyncBN operation instead of inserting a communication node. + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.IMPLICIT) + + # TODO: Temporary solution has no communication cost, + # above action should be added after the SyncBN replace pass completed. + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0], + 1: [mesh_dim_1], + }, + "other": { + 0: [mesh_dim_1], + }, + "output": { + 0: [mesh_dim_0], + 1: [mesh_dim_1], + }, + "running_mean": { + 0: [mesh_dim_1], + }, + "running_var": { + 0: [mesh_dim_1], + }, + "num_batches_tracked": {}, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = { + 0: [mesh_dim_1], + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + # For SyncBN case, we don't need to do communication for gradients of weight and bias. + # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation + # to SyncBN operation instead of inserting a communication node. + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=[mesh_dim_0], + comm_type=CommType.IMPLICIT) + + # TODO: Temporary solution has no communication cost, + # above action should be added after the SyncBN replace pass completed. + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + ''' + Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector. + ''' + + strategy_list = [] + # RS = RS x S + strategy_list.append(self.split_input_channel(0)) + strategy_list.append(self.split_input_channel(1)) + + # RR = RR x R + strategy_list.append(self.non_split()) + + # RS01 = RS01 x S01 + strategy_list.append(self.split_input_channel_1d(0, 1)) + + # The strategies with SYNC_BN are temporarily commented, + # because it requires some additional passes to keep runtime + # computation correctness. + + # TODO: The strategies below should be uncommented after runtime + # passes ready. + # SR = SR x R WITH SYNC_BN + strategy_list.append(self.split_input_batch(0)) + strategy_list.append(self.split_input_batch(1)) + + # SS = SS x S WITH SYNC_BN + strategy_list.append(self.split_input_both_dim(0, 1)) + strategy_list.append(self.split_input_both_dim(1, 0)) + + # S01R = S01R x R WITH SYNC_BN + strategy_list.append(self.split_input_batch_1d(0, 1)) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..fd7f811c8972412eaec88bb1dcfc639cdf1fe630 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py @@ -0,0 +1,111 @@ +import operator +from functools import reduce +from typing import List + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception, +) +from colossalai.tensor.sharding_spec import ShardingSpecException + +from .strategy_generator import StrategyGenerator + +__all__ = ['BinaryElementwiseStrategyGenerator'] + + +class BinaryElementwiseStrategyGenerator(StrategyGenerator): + """ + An BinaryElementwiseStrategyGenerator is a node handler which deals with elementwise operations + which have two operands and broadcasting occurs such as torch.add. + + The logical shape for this operation will be `input other`. + """ + + def validate(self) -> bool: + assert len(self.op_data) == 3, \ + f'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}' + for name, op_data in self.op_data.items(): + if not isinstance(op_data.data, (torch.Tensor, int, float)): + raise TypeError(f'The operation data {name} is not a torch.Tensor/int/float.') + + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + + # since elementwise ops are not compute-intensive, + # we approximate the backward compute cost + # to be twice the fwd compute cost + fwd_compute_cost = reduce(operator.mul, shape) + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + # all input, output and outputs have the same shape + shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + + # compute fwd memory cost in bytes + # as the elementwise ops are not memory-intensive + # we approximate the fwd memroy cost to be the output + # and the backward memory cost to be grad of input and other + input_bytes = self._compute_size_in_bytes(strategy, 'input') + other_bytes = self._compute_size_in_bytes(strategy, 'other') + output_bytes = self._compute_size_in_bytes(strategy, 'output') + fwd_memory_cost = MemoryCost(activation=output_bytes) + bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes) + total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes) + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_memory_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): + # we check for the output logical shape to get the number of dimensions + dim_partition_list = [] + dim_size = len(self.op_data['output'].logical_shape) + + # enumerate all the 2D sharding cases + sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) + dim_partition_list.extend(sharding_list_2d) + + # enumerate all the 1D sharding cases + sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size) + dim_partition_list.extend(sharding_list_1d_on_dim_0) + sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size) + dim_partition_list.extend(sharding_list_1d_on_dim_1) + + # add empty dict for fully replicated case + dim_partition_list.append({}) + + # sharding strategy bookkeeping + strategy_list = [] + + # convert these dim partition dict to sharding strategy + for dim_partition_dict in dim_partition_list: + dim_partition_dict_mapping = dict(input=dim_partition_dict, + other=dim_partition_dict, + output=dim_partition_dict) + + try: + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + communication_action_mapping = {} + + # get name + sharding_seq = sharding_spec_mapping['input'].sharding_sequence + name = f'{sharding_seq} = {sharding_seq} {sharding_seq}' + sharding_strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(sharding_strategy) + except ShardingSpecException: + continue + return strategy_list + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = self.enumerate_all_possible_output(0, 1) + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c2154b3104d3d52e994a2add25ddc796792e1c66 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py @@ -0,0 +1,584 @@ +import copy +import operator +import warnings +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import StrategyGenerator + + +class ConvStrategyGenerator(StrategyGenerator): + """ + ConvStrategyGenerator is a generic class to generate strategies. + The operation data is defined as `output = input x other + bias`. + """ + + def validate(self) -> bool: + ''' + In sanity check, we need make sure the input data having correct dimension size. + For Conv1d, the dim of input data should be 3([N, C, L]). + For Conv2d, the dim of input data should be 4([N, C, H, W]). + For Conv3d, the dim of input data should be 5([N, C, H, W, D]). + ''' + input_op_data = self.op_data['input'] + assert input_op_data.data.dim() in ( + 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + + def update_compute_cost(self, strategy: ShardingStrategy): + ''' + Compute the computation cost per device with this specific strategy. + + Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + ''' + # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + # 1D: (L) * N * Cout * Cin * kernel + # 2D: (H * W) * N * Cout * Cin * kernel + # 3D: (H * W * D) * N * Cout * Cin * kernel + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + if self.has_bias: + # bias add is an element wise operation, so the cost is equal to product of output shape. + bias_compute_cost = reduce(operator.mul, sharded_output_shape) + + output_size = sharded_output_shape[2:] + output_size_product = reduce(operator.mul, output_size) + input_size = sharded_input_shape[2:] + input_size_product = reduce(operator.mul, input_size, 1) + kernel_size = sharded_other_shape[2:] + kernel_size_product = reduce(operator.mul, kernel_size, 1) + batch_size = sharded_input_shape[0] + channel_in = sharded_input_shape[1] + channel_out = sharded_other_shape[1] + + forward_compute_cost = output_size_product * batch_size * channel_in * channel_out * kernel_size_product + + backward_activation_cost = input_size_product * batch_size * channel_in * channel_out * kernel_size_product + backward_weight_cost = output_size_product * batch_size * channel_in * channel_out * kernel_size_product + backward_compute_cost = backward_weight_cost + backward_activation_cost + if self.has_bias: + forward_compute_cost += bias_compute_cost + backward_compute_cost += bias_compute_cost + total_compute_cost = forward_compute_cost + backward_compute_cost + + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'other': self._compute_size_in_bytes(strategy, "other"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + if self.has_bias: + bias_size = self._compute_size_in_bytes(strategy, "bias") + forward_size_mapping['bias'] = bias_size + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + other + bias + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + other_grad + bias_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0] + }, + "other": { + 1: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + 1: [mesh_dim_1] + }, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping = {"input": input_comm_action} + + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + if self.has_bias: + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping["bias"] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_batch(self, mesh_dim_0): + name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0] + }, + "other": {}, + "output": { + 0: [mesh_dim_0], + }, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + if self.has_bias: + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping["bias"] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0], + 1: [mesh_dim_1], + }, + "other": { + 0: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + }, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + output_comm_action = self.get_communication_action( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.AFTER) + + communication_action_mapping = {"output": output_comm_action} + + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + communication_action_mapping["other"] = other_comm_action + if self.has_bias: + if self.is_param("bias"): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping["bias"] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + + dim_partition_dict_mapping = { + "input": { + 1: [mesh_dim_0], + }, + "other": { + 0: [mesh_dim_0], + 1: [mesh_dim_1], + }, + "output": { + 1: [mesh_dim_1], + }, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = { + 0: [mesh_dim_1], + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + output_comm_action = self.get_communication_action( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.AFTER) + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + + communication_action_mapping = {"output": output_comm_action, "input": input_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_in_channel_weight_in_channel(self, mesh_dim_0): + name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' + + dim_partition_dict_mapping = { + "input": { + 1: [mesh_dim_0], + }, + "other": { + 0: [mesh_dim_0], + }, + "output": {}, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + output_comm_action = self.get_communication_action( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.AFTER) + + communication_action_mapping = {"output": output_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_weight_out_channel(self, mesh_dim_0): + name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' + + dim_partition_dict_mapping = { + "input": {}, + "other": { + 1: [mesh_dim_0], + }, + "output": { + 1: [mesh_dim_0], + }, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = { + 0: [mesh_dim_0], + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=0) + + communication_action_mapping = {"input": input_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def non_split(self): + name = f'RR = RR x RR' + + dim_partition_dict_mapping = { + "input": {}, + "other": {}, + "output": {}, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + @ignore_sharding_exception + def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0, mesh_dim_1], + }, + "other": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1], + }, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + if self.has_bias: + if self.is_param("bias"): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping["bias"] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): + name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + dim_partition_dict_mapping = { + "input": { + 1: [mesh_dim_0, mesh_dim_1], + }, + "other": { + 0: [mesh_dim_0, mesh_dim_1], + }, + "output": {}, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + output_comm_action = self.get_communication_action( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.AFTER) + + communication_action_mapping = {"output": output_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' + dim_partition_dict_mapping = { + "input": {}, + "other": { + 1: [mesh_dim_0, mesh_dim_1], + }, + "output": { + 1: [mesh_dim_0, mesh_dim_1], + }, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = { + 0: [mesh_dim_0, mesh_dim_1], + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + + communication_action_mapping = {"input": input_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + strategies = [] + # SS = SR x RS + strategies.append(self.split_input_batch_weight_out_channel(0, 1)) + strategies.append(self.split_input_batch_weight_out_channel(1, 0)) + + # SR = SR x RR + strategies.append(self.split_input_batch(0)) + strategies.append(self.split_input_batch(1)) + + # SR = SS x SR + strategies.append(self.split_input_both_dim_weight_in_channel(0, 1)) + strategies.append(self.split_input_both_dim_weight_in_channel(1, 0)) + + # RS = RS x SS + strategies.append(self.split_input_in_channel_weight_both_channel(0, 1)) + strategies.append(self.split_input_in_channel_weight_both_channel(1, 0)) + + # RR = RS x SR + strategies.append(self.split_input_in_channel_weight_in_channel(0)) + strategies.append(self.split_input_in_channel_weight_in_channel(1)) + + # RS = RR x RS + strategies.append(self.split_weight_out_channel(0)) + strategies.append(self.split_weight_out_channel(1)) + + # RR= RR x RR + strategies.append(self.non_split()) + + # S01R = S01R x RR + strategies.append(self.split_1d_parallel_on_input_batch(0, 1)) + + # RR = RS01 x S01R + strategies.append(self.split_1d_parallel_on_in_channel(0, 1)) + + # RS01 = RR x RS01 + strategies.append(self.split_1d_parallel_on_out_channel(0, 1)) + + return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..82a04ab52e739ae3db29efde2a66f30ff24cb8d0 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py @@ -0,0 +1,310 @@ +import copy +import operator +import warnings +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import StrategyGenerator + + +class EmbeddingStrategyGenerator(StrategyGenerator): + """ + EmbeddingStrategyGenerator is a generic class to generate strategies for nn.Embedding or F.embedding. + The operation data is defined as `output = input x other`. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + ''' + Compute the computation cost per device with this specific strategy. + + Note: The computation cost for the embedding handler is estimated as dense computing now. + It may not be accurate. + ''' + # TODO: estimate the embedding computation cost as sparse operation + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + + input_size_product = reduce(operator.mul, sharded_input_shape) + other_size_product = reduce(operator.mul, sharded_other_shape) + output_size_product = reduce(operator.mul, sharded_output_shape) + + forward_compute_cost = input_size_product * other_size_product + + backward_activation_cost = other_size_product * output_size_product / sharded_output_shape[-1] + backward_weight_cost = input_size_product * other_size_product + backward_compute_cost = backward_weight_cost + backward_activation_cost + + total_compute_cost = forward_compute_cost + backward_compute_cost + + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'other': self._compute_size_in_bytes(strategy, "other"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + other + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + other_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def non_split(self): + name = f'RR = R x RR' + + dim_partition_dict_mapping = { + "input": {}, + "other": {}, + "output": {}, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + @ignore_sharding_exception + def split_input(self, mesh_dim_0): + name = f'S{mesh_dim_0}R = S{mesh_dim_0} x RR' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0] + }, + "other": {}, + "output": { + 0: [mesh_dim_0], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0], + }, + "other": { + 1: [mesh_dim_1], + }, + "output": { + 0: [mesh_dim_0], + 1: [mesh_dim_1], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping = {"input": input_comm_action} + + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "other": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + communication_action_mapping = {} + + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_embedding_dim(self, mesh_dim_0): + name = f'RS{mesh_dim_0} = R x RS{mesh_dim_0}' + + dim_partition_dict_mapping = { + "input": {}, + "other": { + 1: [mesh_dim_0], + }, + "output": { + 1: [mesh_dim_0], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=0) + + communication_action_mapping = {"input": input_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}' + + dim_partition_dict_mapping = { + "input": {}, + "other": { + 1: [mesh_dim_0, mesh_dim_1], + }, + "output": { + 1: [mesh_dim_0, mesh_dim_1], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + + communication_action_mapping = {"input": input_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + strategies = [] + + # RR= R x RR + strategies.append(self.non_split()) + + # SR = S x RR + strategies.append(self.split_input(0)) + strategies.append(self.split_input(1)) + + # SS = S x RS + strategies.append(self.split_input_and_embedding_dim(0, 1)) + strategies.append(self.split_input_and_embedding_dim(1, 0)) + + # S01R = S01 x RR + strategies.append(self.split_1d_parallel_on_input(0, 1)) + + # RS = R x RS + strategies.append(self.split_embedding_dim(0)) + strategies.append(self.split_embedding_dim(1)) + + # RS01 = R x RS01 + strategies.append(self.split_1d_parallel_on_embedding_dim(0, 1)) + + return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..bbeb9a639c835869634417e5ceff1a2cad082339 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py @@ -0,0 +1,89 @@ +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception, +) +from colossalai.tensor.sharding_spec import ShardingSpecException + +from .strategy_generator import StrategyGenerator + +__all__ = ['GetattrGenerator'] + + +class GetattrGenerator(StrategyGenerator): + """ + PlaceholderGenerator is a generic class to generate strategies for placeholder node. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + + # compute fwd cost incurred + # fwd_cost = output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0) + + bwd_mem_cost = MemoryCost(activation=0, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): + # we check for the output logical shape to get the number of dimensions + dim_partition_list = [] + dim_size = len(self.op_data['output'].logical_shape) + + # enumerate all the 2D sharding cases + sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) + dim_partition_list.extend(sharding_list_2d) + + # enumerate all the 1D sharding cases + sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size) + dim_partition_list.extend(sharding_list_1d_on_dim_0) + sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size) + dim_partition_list.extend(sharding_list_1d_on_dim_1) + + # add empty dict for fully replicated case + dim_partition_list.append({}) + + # sharding strategy bookkeeping + strategy_list = [] + + # convert these dim partition dict to sharding strategy + for dim_partition_dict in dim_partition_list: + dim_partition_dict_mapping = dict(output=dim_partition_dict) + + try: + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + communication_action_mapping = {} + + # get name + name = f"get_attr {sharding_spec_mapping['output'].sharding_sequence}" + sharding_strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(sharding_strategy) + except ShardingSpecException: + continue + + return strategy_list + + def collate_strategies(self) -> List[ShardingStrategy]: + return self.enumerate_all_possible_output(0, 1) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..0aeb2e0d4079ea1b302d580554ed7ca24ab7096d --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py @@ -0,0 +1,171 @@ +import copy +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.logging import get_dist_logger +from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.tensor.sharding_spec import ShardingSpecException + +from .strategy_generator import FollowingStrategyGenerator + +__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator'] + + +class GetItemStrategyGenerator(FollowingStrategyGenerator): + """ + GetItemStrategyGenerator is a generic class to generate strategies for operator.getitem. + The operation data is defined as `output = input[other]`. + + There are mainly three use cases: + 1. args_0._meta_data: torch.Tensor, args_1._meta_data: int + 2. args_0._meta_data: torch.Tensor, args_1._meta_data: slice + 3. args_0._meta_data: Tuple[torch.Tensor], args_1._meta_data: int + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + +class TensorStrategyGenerator(GetItemStrategyGenerator): + ''' + Deal with case 1 and 2. + ''' + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + getitem_index = self.op_data['index'].data + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + try: + logger = get_dist_logger() + dim_partition_dict_mapping = {} + communication_action_mapping = {} + dim_partition_dict_for_input = copy.deepcopy( + strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict) + + int_index = False + if isinstance(getitem_index, int): + int_index = True + getitem_dims = [ + 0, + ] + shift_length = 1 + elif isinstance(getitem_index, slice): + getitem_dims = [ + 0, + ] + else: + getitem_dims = [i for i in range(len(getitem_index))] + if isinstance(getitem_index[0], int): + int_index = True + shift_length = len(getitem_index) + + gather_dims = [] + for dim in getitem_dims: + if dim in dim_partition_dict_for_input: + gather_dims.append(dim) + + for dim in gather_dims: + dim_partition_dict_for_input.pop(dim) + dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) + + if int_index: + shift_dim_partition_dict_for_output = {} + for dim, mesh_dim_list in dim_partition_dict_for_output.items(): + shift_dim_partition_dict_for_output[dim - shift_length] = mesh_dim_list + dim_partition_dict_for_output = shift_dim_partition_dict_for_output + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + except ShardingSpecException as e: + logger.debug(e) + continue + strategy_list.append(strategy) + + for strategy in strategy_list: + self.update_communication_cost(strategy) + self.update_compute_cost(strategy) + self.update_memory_cost(strategy) + + return strategy_list + + +class TensorTupleStrategyGenerator(GetItemStrategyGenerator): + ''' + Deal with case 3. + ''' + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + index = self.op_data["index"].data + + for strategy_index, strategy in enumerate(self.predecessor_node.strategies_vector): + # the sharding spec for input in this case is a tuple of ShardingSpec. + sharding_spec_for_input = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_output = sharding_spec_for_input[index].dim_partition_dict + dim_partition_dict_mapping = {} + communication_action_mapping = {} + dim_partition_dict_mapping = { + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + sharding_spec_mapping["input"] = sharding_spec_for_input + input_sharding_info = f"get the {index} element from (" + for sharding_spec in sharding_spec_for_input: + input_sharding_info += f'{sharding_spec.sharding_sequence}, ' + input_sharding_info += ")" + name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..fbb6070f7e82c9a41848c626c6271d1a7b9d73ee --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py @@ -0,0 +1,195 @@ +import copy +import operator +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import StrategyGenerator + +__all__ = ['LayerNormGenerator'] + + +class LayerNormGenerator(StrategyGenerator): + """ + LayerNormGenerator is a generic class to generate strategies for LayerNorm operation. + The operation data is defined as `output = input x other + bias`. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + ''' + Compute the computation cost per device with this specific strategy. + + Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + ''' + # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + # TODO: a constant coefficient need to be added. + + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_weight_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + if self.has_bias: + # bias add is an element wise operation, so the cost is equal to product of output shape. + bias_compute_cost = reduce(operator.mul, sharded_weight_shape) + # in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization. + input_batch_shape = sharded_input_shape[:-len(sharded_weight_shape)] + input_batch_product = reduce(operator.mul, input_batch_shape, 1) + norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1) + forward_compute_cost = input_batch_product * norm_kernel_product + backward_activation_compute_cost = input_batch_product * norm_kernel_product + # To compute gradient of on norm kernel element requires input_batch_product times computation, so + # the total cost is input_batch_product * norm_kernel_product + backward_weight_compute_cost = input_batch_product * norm_kernel_product + backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost + if self.has_bias: + forward_compute_cost += bias_compute_cost + backward_compute_cost += bias_compute_cost + total_compute_cost = forward_compute_cost + backward_compute_cost + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'other': self._compute_size_in_bytes(strategy, "other"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + if self.has_bias: + bias_size = self._compute_size_in_bytes(strategy, "bias") + forward_size_mapping['bias'] = bias_size + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + other + bias + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + other_grad + bias_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def _generate_strategy_with_dim_partition(self, dim_partition): + dim_partition_dict_mapping = { + "input": dim_partition, + "other": {}, + "output": dim_partition, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence} x {sharding_spec_mapping["other"].sharding_sequence}' + total_mesh_dim_list = [] + for mesh_dim_list in dim_partition.values(): + total_mesh_dim_list.extend(mesh_dim_list) + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. + if len(total_mesh_dim_list) == 1: + total_mesh_dim_list = total_mesh_dim_list[0] + communication_action_mapping = {} + + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.HOOK) + communication_action_mapping["other"] = other_comm_action + + if self.has_bias: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.HOOK) + communication_action_mapping["bias"] = bias_comm_action + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + return strategy + + def split_input_batch_single_mesh_dim(self, mesh_dim_0, batch_dimension_length): + strategy_list = [] + dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length) + for dim_partition in dim_partition_list: + strategy = self._generate_strategy_with_dim_partition(dim_partition) + strategy_list.append(strategy) + return strategy_list + + def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1, batch_dimension_length): + strategy_list = [] + dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length) + for dim_partition in dim_partition_list: + strategy = self._generate_strategy_with_dim_partition(dim_partition) + strategy_list.append(strategy) + return strategy_list + + @ignore_sharding_exception + def non_split(self): + name = f'RR = RR x R' + dim_partition_dict_mapping = { + "input": {}, + "other": {}, + "output": {}, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + ''' + Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector. + ''' + strategy_list = [] + input_data_dim = len(self.op_data["input"].logical_shape) + weight_data_dim = len(self.op_data["other"].logical_shape) + # in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization. + batch_dimension_length = input_data_dim - weight_data_dim + + # SR = SR x R with single mesh dim on batch dimensions + strategy_list.extend(self.split_input_batch_single_mesh_dim(0, batch_dimension_length)) + strategy_list.extend(self.split_input_batch_single_mesh_dim(1, batch_dimension_length)) + + # SR = SR x R with both mesh dims on batch dimensions + strategy_list.extend(self.split_input_batch_both_mesh_dim(0, 1, batch_dimension_length)) + + # RR = RR x R + strategy_list.append(self.non_split()) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce5a08f2d6b70d20f10476309034ab1a26b75d1 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -0,0 +1,1020 @@ +import operator +from ast import arg +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.options import SolverPerference +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import StrategyGenerator + + +class MatMulStrategyGenerator(StrategyGenerator): + """ + MatMulStrategyGenerator is a generic class to cover all matrix multiplication cases. + The operation data is defined as `output = input x other + bias`. + """ + + def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'other': self._compute_size_in_bytes(strategy, "other"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + if self.has_bias: + bias_size = self._compute_size_in_bytes(strategy, "bias") + size_mapping['bias'] = bias_size + + # compute fwd cost incurred + # fwd_cost = input + other + bias + output + fwd_activation_cost = sum([v for k, v in size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bias_grad + bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + 0) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + +class DotProductStrategyGenerator(MatMulStrategyGenerator): + + def validate(self) -> bool: + input_op_data = self.op_data['input'] + other_op_data = self.op_data['other'] + assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1 + + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + fwd_compute_cost = sharded_input_shape[0] + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + return compute_cost + + @ignore_sharding_exception + def no_split(self): + name = f'R = R dot R' + dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + communication_action_mapping = {} + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_one_dim(self, mesh_dim): + name = f'R = S{mesh_dim} dot S{mesh_dim}' + + # get sharding spec + dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication action + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['output'], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.AFTER) + communication_action_mapping = {"output": output_comm_action} + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + + # do not split dimensions for dot product + # R = R dot R + strategy_list.append(self.no_split()) + + # split two tensors in the same dimensions + # S = S dot S + strategy_list.append(self.split_one_dim(0)) + strategy_list.append(self.split_one_dim(1)) + + return strategy_list + + +class MatVecStrategyGenerator(MatMulStrategyGenerator): + + def validate(self) -> bool: + input_op_data = self.op_data['input'] + other_op_data = self.op_data['other'] + assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1 + + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + fwd_compute_cost = sharded_input_shape[0] + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + return compute_cost + + @ignore_sharding_exception + def no_split(self): + name = "R = R x R" + dim_partition_dict = {"input": {}, "other": {}, "output": {}} + + if self.has_bias: + dim_partition_dict['bias'] = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + @ignore_sharding_exception + def split_input_batch(self, mesh_dim): + name = f'S{mesh_dim}R = S{mesh_dim}R x R' + + # get sharding spec + dim_partition_dict = { + "input": { + 0: [mesh_dim] + }, + "other": {}, + "output": { + 0: [mesh_dim] + }, + } + + if self.has_bias: + dim_partition_dict['bias'] = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication action + communication_action_mapping = {} + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=1) + communication_action_mapping['other'] = other_comm_action + + if self.has_bias: + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=2) + communication_action_mapping['bias'] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + + # no split + strategy_list.append(self.no_split()) + + # split the batch dim for the first tensor only + strategy_list.append(self.split_input_batch(0)) + strategy_list.append(self.split_input_batch(1)) + + return strategy_list + + +class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): + + def __init__(self, + operation_data_mapping, + device_mesh, + linear_projection_type='linear', + solver_perference=SolverPerference.STANDARD): + super().__init__(operation_data_mapping, device_mesh) + self.linear_projection_type = linear_projection_type + self.solver_perference = solver_perference + + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + # C = AB + # C: [M, N], A: [M, P], B: [P, N] + # fwd cost = MNP (only count mul) + # bwd: 2 x fwd_cost + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + dim_m_val = reduce(operator.mul, sharded_input_shape[:-1]) + dim_n_val = sharded_other_shape[-1] + dim_p_val = sharded_other_shape[0] + + fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=bwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + strategy.compute_cost = compute_cost + + def dp_strategies(self) -> List[ShardingStrategy]: + strategies = [] + + # S01R = S01R x RR + strategies.append(self.split_lhs_1st_dim_1d(0, 1)) + + return strategies + + def tp_strategies(self) -> List[ShardingStrategy]: + strategies = [] + + # RR = RS01 x S01R + strategies.append(self.split_lhs_2nd_dim_1d(0, 1)) + + # RS01 = RR x RS01 + strategies.append(self.split_rhs_2nd_dim_1d(0, 1)) + + # RS = RS x SS + strategies.append(self.split_rhs_space_both_contract(0, 1)) + strategies.append(self.split_rhs_space_both_contract(1, 0)) + + # RR= RS x SR + strategies.append(self.recompute_split_both_contract(0)) + strategies.append(self.recompute_split_both_contract(1)) + + # RS = RR x RS + strategies.append(self.split_rhs_space_only(0)) + strategies.append(self.split_rhs_space_only(1)) + + return strategies + + def mix_strategies(self) -> List[ShardingStrategy]: + strategies = [] + + # SS = SR x RS + strategies.append(self.split_lhs_space_rhs_space(0, 1)) + strategies.append(self.split_lhs_space_rhs_space(1, 0)) + + # SR = SS x SR + strategies.append(self.split_lhs_space_both_contract(0, 1)) + strategies.append(self.split_lhs_space_both_contract(1, 0)) + + # RR = RR x RR + strategies.append(self.non_split()) + + return strategies + + def collate_strategies(self) -> List[ShardingStrategy]: + strategies = [] + + if self.solver_perference == SolverPerference.STANDARD: + strategies.extend(self.dp_strategies()) + strategies.extend(self.tp_strategies()) + strategies.extend(self.mix_strategies()) + elif self.solver_perference == SolverPerference.DP: + strategies.extend(self.dp_strategies()) + elif self.solver_perference == SolverPerference.TP: + strategies.extend(self.tp_strategies()) + + return strategies + + @ignore_sharding_exception + def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): + # handle case SS = SR x RS + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0] + }, + "other": { + -1: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + -1: [mesh_dim_1] + }, + } + + # linear bias only has one dimension, but addmm bias has same dimensions + # as the output logically. + if self.linear_projection_type == 'linear': + dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]} + elif self.linear_projection_type == 'addmm': + dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]} + else: + raise ('Unsupported linear projection type') + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + communication_action_mapping = {} + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping['input'] = input_comm_action + communication_action_mapping['other'] = other_comm_action + + # we only add allreduce comm action for linear bias, because + # allreduce comm action for addmm bias will be considered in post processing + if self.has_bias and self.linear_projection_type == 'linear': + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping['bias'] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): + # handle the case SR = SS x SR + name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + + # get sharding spec mapping + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0], + -1: [mesh_dim_1] + }, + "other": { + 0: [mesh_dim_1] + }, + "bias": {}, + "output": { + 0: [mesh_dim_0] + }, + } + + # linear bias only has one dimension, but addmm bias has same dimensions + # as the output logically. + if self.linear_projection_type == 'linear': + dim_partition_dict_mapping['bias'] = {} + elif self.linear_projection_type == 'addmm': + dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]} + else: + raise ('Unsupported linear projection type') + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication action mapping + communication_action_mapping = {} + + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.AFTER) + + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping['other'] = other_comm_action + communication_action_mapping['output'] = output_comm_action + + # we only add allreduce comm action for linear bias, because + # allreduce comm action for addmm bias will be considered in post processing + if self.has_bias and self.linear_projection_type == 'linear': + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping['bias'] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + + # get sharding specs + dim_partition_dict_mapping = { + "input": { + -1: [mesh_dim_0] + }, + "other": { + 0: [mesh_dim_0], + -1: [mesh_dim_1] + }, + "bias": { + -1: [mesh_dim_1] + }, + "output": { + -1: [mesh_dim_1] + }, + } + + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication actions + communication_action_mapping = {} + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['output'], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.AFTER) + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['input'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping["input"] = input_comm_action + communication_action_mapping['output'] = output_comm_action + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def recompute_split_both_contract(self, mesh_dim): + name = f'RR = RS{mesh_dim} x S{mesh_dim}R' + + # get sharding spec + dim_partition_dict_mapping = { + "input": { + -1: [mesh_dim] + }, + "other": { + 0: [mesh_dim] + }, + "bias": {}, + "output": {}, + } + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication action + communication_action_mapping = {} + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['output'], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.AFTER) + + communication_action_mapping['output'] = output_comm_action + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_rhs_space_only(self, mesh_dim): + name = f'RS{mesh_dim} = RR x RS{mesh_dim}' + + # get sharding spec + dim_partition_dict_mapping = { + "input": {}, + "other": { + -1: [mesh_dim] + }, + "bias": { + -1: [mesh_dim] + }, + "output": { + -1: [mesh_dim] + }, + } + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication actions + communication_action_mapping = {} + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['input'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=0) + + communication_action_mapping['input'] = input_comm_action + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + # get sharding spec + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "other": {}, + "bias": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1] + }, + } + + # linear bias only has one dimension, but addmm bias has same dimensions + # as the output logically. + if self.linear_projection_type == 'linear': + dim_partition_dict_mapping['bias'] = {} + elif self.linear_projection_type == 'addmm': + dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0, mesh_dim_1]} + else: + raise ('Unsupported linear projection type') + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication action + communication_action_mapping = {} + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=1) + communication_action_mapping['other'] = other_comm_action + + # we only add allreduce comm action for linear bias, because + # allreduce comm action for addmm bias will be considered in post processing + if self.has_bias and self.linear_projection_type == 'linear': + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping['bias'] = bias_comm_action + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + + # get sharding spec + dim_partition_dict_mapping = { + "input": { + -1: [mesh_dim_0, mesh_dim_1] + }, + "other": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "bias": {}, + "output": {}, + } + + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication action + communication_action_mapping = {} + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['output'], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.AFTER) + communication_action_mapping['output'] = output_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' + + # get sharding spec + dim_partition_dict_mapping = { + "input": {}, + "other": { + -1: [mesh_dim_0, mesh_dim_1] + }, + "bias": { + -1: [mesh_dim_0, mesh_dim_1] + }, + "output": { + -1: [mesh_dim_0, mesh_dim_1] + }, + } + + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication action + communication_action_mapping = {} + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['input'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['input'] = input_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def non_split(self): + name = f'RR = RR x RR' + + # get sharding spec + dim_partition_dict_mapping = { + "input": {}, + "other": {}, + "bias": {}, + "output": {}, + } + + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication action + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def validate(self) -> bool: + assert "input" in self.op_data + assert "other" in self.op_data + + # make sure the other has 2 dim + input_data = self.op_data['input'] + other_data = self.op_data['other'] + assert input_data.data.dim() > 0 and other_data.data.dim() == 2 + assert other_data.logical_shape[0] == input_data.logical_shape[-1] + + if self.has_bias: + bias_data = self.op_data['bias'] + assert bias_data.logical_shape[-1] == other_data.logical_shape[-1] + + +class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): + """ + Generate sharding strategies for the batched matrix multiplication. + + A batched matrix multiplication can be viewed as + [b, i, k] x [b, k, j] -> [b, i, j] + + The bias term is considered to have a 2D logical shape. + + Note: This class will be used to generate strategies for torch.bmm + and torch.addbmm. However, the result of torch.addbmm is not correct, + some extra runtime apply actions are required to keep numerical correctness. + """ + + # TODO: torch.addbmm correctness issue need to be fixed. + def __init__(self, *args, **kwargs): + self.squeeze_batch_dim = False + super().__init__(*args, **kwargs) + + def _pop_batch_dim_sharding_for_output(self, dim_partition_dict): + # remove partition dict for dim 0 + dim_partition_dict['output'].pop(0, None) + + # decrease the remaining dim index by 1 + temp_dim_partition = {} + keys = list(dim_partition_dict['output'].keys()) + for key in keys: + val = dim_partition_dict['output'].pop(key) + temp_dim_partition[key - 1] = val + dim_partition_dict['output'].update(temp_dim_partition) + + def validate(self) -> bool: + input_op_data = self.op_data['input'] + other_op_data = self.op_data['other'] + assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3 + + if 'bias' in self.op_data: + bias_op_data = self.op_data['bias'] + assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2 + + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul, + self.op_data['output'].data.shape) + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + strategy.compute_cost = compute_cost + + @ignore_sharding_exception + def split_one_batch_dim(self, mesh_dim): + name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}' + + # get sharding_spec + dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}} + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication actions + communication_action_mapping = {} + if self.has_bias: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1): + name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "other": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "bias": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1] + } + } + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication actions + communication_action_mapping = {} + if self.has_bias: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1): + name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0], + 1: [mesh_dim_1] + }, + "other": { + 0: [mesh_dim_0] + }, + "bias": { + 0: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + 1: [mesh_dim_1] + } + } + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication actions + communication_action_mapping = {} + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=1) + communication_action_mapping['other'] = other_comm_action + + if self.has_bias: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action + # for addbmm case, other is the third argument instead of second. + communication_action_mapping['other'].arg_index += 1 + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1): + name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0] + }, + "other": { + 0: [mesh_dim_0], + 2: [mesh_dim_1] + }, + "bias": { + 1: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + 2: [mesh_dim_1] + } + } + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication actions + communication_action_mapping = {} + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['input'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['input'] = input_comm_action + + if self.has_bias: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE) + communication_action_mapping['bias'] = bias_comm_action + # for addbmm case, other is the second argument instead of first. + communication_action_mapping['input'].arg_index += 1 + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1): + name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}' + dim_partition_dict = { + "input": { + 0: [mesh_dim_0], + 2: [mesh_dim_1] + }, + "other": { + 0: [mesh_dim_0], + 1: [mesh_dim_1] + }, + "bias": {}, + "output": { + 0: [mesh_dim_0], + } + } + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication actions + communication_action_mapping = {} + output_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['output'], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.AFTER) + communication_action_mapping['output'] = output_comm_action + + if self.has_bias: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + device_mesh_is_1d = True + if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape: + device_mesh_is_1d = False + + if device_mesh_is_1d: + # split only the batch dimension + # Sb = Sb x Sb + # can be None as it is only for 1D device mesh + # only for 1D device mesh + if len(self.device_mesh.mesh_shape) == 1: + mesh_dim = 0 + else: + mesh_dim = self.device_mesh.mesh_shape.index(1) + strategy_list.append(self.split_one_batch_dim(mesh_dim)) + else: + # for 2D device mesh + # split batch dim of two inputs and the i dim of the first tensor + # SbSi = SbSi x Sb + strategy_list.append(self.split_batch_dim_lhs_space(0, 1)) + strategy_list.append(self.split_batch_dim_lhs_space(1, 0)) + + # split batch dim of two inputs and the j of the second tensor + # SbSj = Sb x SbSj + strategy_list.append(self.split_batch_dim_rhs_space(0, 1)) + strategy_list.append(self.split_batch_dim_rhs_space(1, 0)) + + # split batch dim of two inputs and the k dim of two inputs + # Sb = SbSk x SbSk, need to all-reduce by k dim + strategy_list.append(self.split_batch_dim_both_contract(0, 1)) + strategy_list.append(self.split_batch_dim_both_contract(1, 0)) + + # split two batch dim + strategy_list.append(self.split_two_batch_dim(0, 1)) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..9df6d2fbfa127b71eba66256fb27f204fb1da5fe --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py @@ -0,0 +1,118 @@ +import copy +import operator +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception, +) + +from .strategy_generator import StrategyGenerator + + +class NormalPoolStrategyGenerator(StrategyGenerator): + """ + NormalPoolStrategyGenerator is a generic class to generate strategies for pool operation like MaxPoolxd. + The reason we call this normal pool is AvgPoolxd and MaxPoolxd are taking the kernel size element from image, + and reduce them depening on the operation type. + """ + + def validate(self) -> bool: + ''' + In sanity check, we need make sure the input data having correct dimension size. + For Pool1d, the dim of input data should be 3([N, C, L]). + For Pool2d, the dim of input data should be 4([N, C, H, W]). + For Pool3d, the dim of input data should be 5([N, C, H, W, D]). + ''' + input_op_data = self.op_data['input'] + assert input_op_data.data.dim() in ( + 3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].' + + def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem: + ''' + Compute the computation cost per device with this specific strategy. + + Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + ''' + # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + # 1D: (Lout) * N * C * kernel + # 2D: (H * W) * N * Cout * Cin * kernel + # 3D: (H * W * D) * N * Cout * Cin * kernel + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + + kernel_size = self.op_data["other"].data + if isinstance(kernel_size, int): + kernel_size = [kernel_size] * (len(sharded_output_shape) - 2) + kernel_size_product = reduce(operator.mul, kernel_size) + output_size_product = reduce(operator.mul, sharded_output_shape) + input_size_product = reduce(operator.mul, sharded_input_shape) + + forward_compute_cost = output_size_product * kernel_size_product + backward_compute_cost = input_size_product * kernel_size_product + + total_compute_cost = forward_compute_cost + backward_compute_cost + + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items()]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, parameter=0) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def _generate_strategy_with_dim_partition(self, dim_partition): + dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}' + communication_action_mapping = {} + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + return strategy + + def enumerate_all_possible_batch_dimensions_dim_partition(self, mesh_dim_0, mesh_dim_1): + dim_partition_list = [] + dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_0, 2)) + dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_1, 2)) + dim_partition_list.extend(enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, 2)) + # append {} for non_split case + dim_partition_list.append({}) + + return dim_partition_list + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + + dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1) + for dim_partition in dim_partition_list: + strategy = self._generate_strategy_with_dim_partition(dim_partition) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..69d1642d4f808038d0eeb58547a7b1c0604c85eb --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py @@ -0,0 +1,121 @@ +from typing import Dict, List + +from torch.fx import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.device.device_mesh import DeviceMesh + +from .strategy_generator import OutputStrategyGenerator + +__all__ = ['OutputGenerator'] + + +class OutputGenerator(OutputStrategyGenerator): + """ + OutputGenerator is a generic class to generate strategies for Output Node. + """ + + def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, + predecessor_nodes: List[Node], output_option: str): + super().__init__(operation_data_mapping, device_mesh, predecessor_nodes) + self.output_option = output_option + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + fwd_mem_cost = MemoryCost(activation=0, parameter=0) + + bwd_mem_cost = MemoryCost(activation=0, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=0, parameter=0) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def replica_strategy(self) -> List[ShardingStrategy]: + """ + Generate replica strategy for output node. + """ + dim_partition_dict_mapping = {} + dim_partition_dict_for_output = [] + for index, _ in enumerate(self.predecessor_nodes): + mapping_name = f"input_{index}" + if isinstance(self.op_data[mapping_name].data, (tuple, list)): + dim_partition_dict_for_input = [{} for _ in range(len(self.op_data[mapping_name].data))] + else: + dim_partition_dict_for_input = {} + dim_partition_dict_mapping[mapping_name] = dim_partition_dict_for_input + dim_partition_dict_for_output.append(dim_partition_dict_for_input) + + if len(dim_partition_dict_for_output) == 1: + dim_partition_dict_for_output = dim_partition_dict_for_output[0] + else: + dim_partition_dict_for_output = tuple(dim_partition_dict_for_output) + + dim_partition_dict_mapping['output'] = dim_partition_dict_for_output + + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Replica Output' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + return strategy + + def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]: + """ + Generate distributed strategy for output node. + """ + # TODO: need to take care of the case when the first element of output only need to be sharded. + output_op_data = self.op_data['output'] + if isinstance(output_op_data.data, tuple): + length = len(output_op_data.data) + dim_partition_dict_mapping = { + "output": [{ + 0: mesh_list + }] * length, + } + else: + dim_partition_dict_mapping = { + "output": { + 0: mesh_list + }, + } + for index, _ in enumerate(self.predecessor_nodes): + mapping_name = f"input_{index}" + dim_partition_dict_mapping[mapping_name] = {0: mesh_list} + + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Distributed Output' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + return strategy + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + mesh_list = [0, 1] + if self.output_option == 'replicated': + strategy_list.append(self.replica_strategy()) + elif self.output_option == 'distributed': + strategy_list.append(self.distributed_strategy(mesh_list)) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..779a7ced93bb503c390bd89382d087230e48d2f0 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py @@ -0,0 +1,100 @@ +from typing import Dict, List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.device.device_mesh import DeviceMesh + +from .strategy_generator import StrategyGenerator + +__all__ = ['PlaceholderGenerator'] + + +class PlaceholderGenerator(StrategyGenerator): + """ + PlaceholderGenerator is a generic class to generate strategies for placeholder node. + """ + + def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, + placeholder_option: str): + super().__init__(operation_data_mapping, device_mesh) + self.placeholder_option = placeholder_option + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + + # compute fwd cost incurred + # fwd_cost = output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0) + + bwd_mem_cost = MemoryCost(activation=0, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def replica_placeholder(self) -> ShardingStrategy: + """ + Generate replica strategy for placeholder node. + """ + dim_partition_dict_mapping = { + "output": {}, + } + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Replica Placeholder' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + return strategy + + def distributed_placeholder(self, mesh_list) -> ShardingStrategy: + """ + Generate distributed strategy for placeholder node. + """ + dim_partition_dict_mapping = { + "output": { + 0: mesh_list + }, + } + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Distributed Placeholder' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + return strategy + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + if self.placeholder_option == 'distributed': + mesh_list = [0, 1] + distributed_strategy = self.distributed_placeholder(mesh_list) + strategy_list.append(distributed_strategy) + else: + assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported' + replicated_strategy = self.replica_placeholder() + strategy_list.append(replicated_strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..24f75e352935f149e02c399b2c8e90c0f3ddc2f7 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py @@ -0,0 +1,366 @@ +import copy +from typing import List + +from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + check_keep_sharding_status, + detect_reshape_mapping, + infer_output_dim_partition_dict, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.tensor.sharding_spec import ShardingSpec + +__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator'] + + +class ReshapeGenerator(FollowingStrategyGenerator): + """ + ReshapeGenerator is the base class for all the reshape operation. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + return super().collate_strategies() + + +class ViewGenerator(ReshapeGenerator): + """ + ViewGenerator deals with the sharding strategies of view op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + + origin_shape = self.op_data['input'].data.shape + tgt_shape = self.op_data['tgt_shape'].data + + reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) + + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict) + + if keep_sharding_status: + dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input, + reshape_mapping_dict) + else: + dim_partition_dict_for_output = {} + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + if keep_sharding_status: + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + else: + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}' + + # add comm action for converting input to fully replicated + total_mesh_dim_list = [] + for mesh_dim_list in dim_partition_dict_for_input.values(): + total_mesh_dim_list.extend(mesh_dim_list) + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. + if len(total_mesh_dim_list) == 1: + total_mesh_dim_list = total_mesh_dim_list[0] + # the total mesh dim list only has one element, so the shard dim has only one element as well. + shard_dim = list(dim_partition_dict_for_input.keys())[0] + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.BEFORE, + arg_index=0) + # it will gather the input through gather_dim during forward phase. + input_comm_action.comm_spec.gather_dim = shard_dim + # it will split the input activation grad through shard_dim during backward phase. + input_comm_action.comm_spec.shard_dim = shard_dim + + elif len(total_mesh_dim_list) >= 2: + source_spec = sharding_spec_mapping["input"] + target_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=source_spec.entire_shape, + dim_partition_dict={}) + comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) + + else: + input_comm_action = None + + if input_comm_action is not None: + communication_action_mapping["input"] = input_comm_action + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class PermuteGenerator(ReshapeGenerator): + """ + PermuteGenerator deals with the sharding strategies of permute op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + + permute_dims = self.op_data['permute_dims'].data + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = {} + for dim_index, permute_dim in enumerate(permute_dims): + if permute_dim in dim_partition_dict_for_input: + dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim] + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class TransposeGenerator(ReshapeGenerator): + """ + TransposeGenerator deals with the sharding strategies of permute op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = {} + + transpose_dims = self.op_data['transpose_dims'].data + dim_0 = transpose_dims[0] + dim_1 = transpose_dims[1] + for dim, sharded_dims in dim_partition_dict_for_input.items(): + if dim == dim_0: + dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0] + elif dim == dim_1: + dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1] + else: + dim_partition_dict_for_output[dim] = sharded_dims + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class SplitGenerator(ReshapeGenerator): + """ + SplitGenerator deals with the sharding strategies of split op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + recover_dims = None + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) + split_size, split_dim = self.op_data['split_info'].data + + if split_dim in dim_partition_dict_for_input: + recover_dims = dim_partition_dict_for_input.pop(split_dim) + + dim_partition_dict_for_output = [ + copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data)) + ] + assert len(dim_partition_dict_for_output) >= 2 + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}' + + # add comm action if the input need to be recovered to replica in the split dimension. + if recover_dims: + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. + if len(recover_dims) == 1: + recover_dims = recover_dims[0] + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + logical_process_axis=recover_dims, + comm_type=CommType.BEFORE, + arg_index=0) + # it will gather the input through gather_dim during forward phase. + input_comm_action.comm_spec.gather_dim = split_dim + # it will split the input activation grad through split_dim during backward phase. + input_comm_action.comm_spec.shard_dim = split_dim + + elif len(recover_dims) >= 2: + # original sharding spec + source_spec = input_sharding_spec + # target sharding spec + target_spec = sharding_spec_mapping["input"] + comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) + + else: + input_comm_action = None + + if input_comm_action is not None: + communication_action_mapping["input"] = input_comm_action + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class DefaultReshapeGenerator(ReshapeGenerator): + """ + DefaultReshapeGenerator which deals with the sharding strategies of Reshape Op which have to recover the tensor + to Replica status. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + # For default reshape strategy, to keep the computing correctness we keep the + # sharding spec of input is fully replicated. In addition, we will keep the output + # in replica status and let the successor node choose the way to resharding the + # output node. Therefore, the different strategies of input node with same + # output sharding spec will generate same strategy for reshape function. + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = {} + if isinstance(self.op_data["output"].data, tuple): + dim_partition_dict_for_output = [{} for _ in range(len(self.op_data["output"].data))] + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}' + + total_mesh_dim_list = [] + for mesh_dim_list in dim_partition_dict_for_input.values(): + total_mesh_dim_list.extend(mesh_dim_list) + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. + if len(total_mesh_dim_list) == 1: + total_mesh_dim_list = total_mesh_dim_list[0] + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.BEFORE, + arg_index=0) + input_comm_action.comm_spec.gather_dim = total_mesh_dim_list + input_comm_action.comm_spec.shard_dim = total_mesh_dim_list + + elif len(total_mesh_dim_list) >= 2: + source_spec = sharding_spec_mapping["input"] + target_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=source_spec.entire_shape, + dim_partition_dict={}) + comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) + + else: + input_comm_action = None + + if input_comm_action is not None: + communication_action_mapping["input"] = input_comm_action + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ebadd043e2c2e563fcdc611567bca3ededfa51 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py @@ -0,0 +1,104 @@ +import copy +import operator +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + check_keep_sharding_status, + detect_reshape_mapping, + infer_output_dim_partition_dict, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +__all__ = ['SoftmaxGenerator'] + + +class SoftmaxGenerator(FollowingStrategyGenerator): + """ + SoftmaxGenerator is used to generate strategies for torch.nn.Softmax or F.softmax. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + ''' + Compute the computation cost per device with this specific strategy. + ''' + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + input_size_product = reduce(operator.mul, sharded_input_shape) + output_size_product = reduce(operator.mul, sharded_output_shape) + + forward_compute_cost = output_size_product * 2 + backward_compute_cost = input_size_product + total_compute_cost = forward_compute_cost + backward_compute_cost + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) + softmax_dim = self.op_data['softmax_dim'].data + + if softmax_dim in dim_partition_dict_for_input: + recover_dims = dim_partition_dict_for_input.pop(softmax_dim) + + dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..6d68521aaea7989c085c24f32a8dc92f4b1b71fc --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py @@ -0,0 +1,298 @@ +import operator +from abc import ABC, abstractmethod +from functools import reduce +from typing import Any, Dict, List, Union + +import torch +from torch.fx import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + OperationData, + OperationDataType, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.tensor.utils import convert_dim_partition_dict + + +class StrategyGenerator(ABC): + """ + StrategyGenerator is used to generate the same group of sharding strategies. + + TODO: remove the original strategy_generator.py after refactoring + """ + + def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh): + self.op_data = operation_data_mapping + self.device_mesh = device_mesh + + # validate the whether operation data is of desired value + self.validate() + + @property + def has_bias(self): + """ + A utility method to check for the existence of bias operand for convenience. + """ + return 'bias' in self.op_data + + def is_param(self, op_data_name): + other_data = self.op_data[op_data_name] + return other_data.type == OperationDataType.PARAM + + def is_buffer(self, op_data_name): + other_data = self.op_data[op_data_name] + return other_data.type == OperationDataType.BUFFER + + def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec], + communication_action_mapping: Dict[str, CommSpec]): + """ + A factory method to produce a ShardingStrategy object. + + Args: + sharding_spec_mapping (Dict[str, ShardingSpec]): the mapping between the operation data name and the ShardingSpec object. + communication_action_mapping (Dict[str, CommSpec]): the mapping between the operation data name and the CommSpec object. + """ + sharding_specs = self.replace_op_name_with_op_data(sharding_spec_mapping) + communication_actions = self.replace_op_name_with_op_data(communication_action_mapping) + return ShardingStrategy(name=name, sharding_specs=sharding_specs, communication_actions=communication_actions) + + def to_sharding_spec_mapping(self, mapping: Dict[str, Dict[int, List[int]]]): + """ + A utility method to convert the the dim partition dict to a ShardingSpec object. + + Args: + mapping (Dict[str, Dict[int, List[int]]]): the key of the mapping is the operation data name and the value is a dim partition dictionary. + + Notes: + The op_data.data is commonly type of torch.Tensor, torch.nn.Parameter, so the sharding spec is easy to create from the shape of the data. + However, if the op_data.data is of other non-iterative types, such as float or int, we should return None. If the op_data.data is of some iterative types, such as + list or tuple, we should return a list of ShardingSpec objects follow the same rule as above mentioned. + """ + results = {} + for op_data_name, dim_partition_dict in mapping.items(): + if op_data_name in self.op_data: + op_data = self.op_data[op_data_name] + + def _to_sharding_spec( + data: any, logical_shape: any, + dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]: + """ + This is a recursive function to convert the dim partition dict to a ShardingSpec object. + """ + if isinstance(data, torch.Tensor): + dim_size = len(logical_shape) + dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict) + sharding_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=logical_shape, + dim_partition_dict=dim_partition_dict) + return sharding_spec + elif isinstance(data, (list, tuple)): + sharding_spec = [] + for data_element, logical_shape_element, dim_partition_dict_element in zip( + data, logical_shape, dim_partition_dict): + sharding_spec.append( + _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element)) + return sharding_spec + else: + return None + + sharding_spec = _to_sharding_spec(op_data.data, op_data.logical_shape, dim_partition_dict) + results[op_data_name] = sharding_spec + return results + + def replace_op_name_with_op_data(self, mapping: Dict[str, Any]): + """ + Convert the key of the dictionary from the operation data name to an OperationData object. + """ + results = {} + for k, v in mapping.items(): + op_data = self.op_data[k] + results[op_data] = v + return results + + def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern, + logical_process_axis: Union[int, List[int]]): + """ + A factory method to produce a CommSpec object. + """ + return CommSpec(comm_pattern=communication_pattern, + sharding_spec=sharding_spec, + logical_process_axis=logical_process_axis) + + def get_communication_action(self, + sharding_spec: ShardingSpec, + communication_pattern: CollectiveCommPattern, + logical_process_axis: Union[int, List[int]], + comm_type: CommType, + arg_index: int = -1, + key_for_kwarg: any = None) -> CommAction: + """ + A factory method to produce a CommAction object. + """ + return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec, + communication_pattern=communication_pattern, + logical_process_axis=logical_process_axis), + comm_type=comm_type, + arg_index=arg_index, + key_for_kwarg=key_for_kwarg) + + def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + """ + Compute the communication cost involved in the forward and backward iteration. + """ + + comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0) + + def _compute_and_add(op_data: OperationData, comm_spec: CommSpec): + num_ele_in_comm = comm_spec.get_comm_cost() + dtype = op_data.data.dtype + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + for phase, cost in num_ele_in_comm.items(): + num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes + comm_cost.fwd += num_ele_in_comm['forward'] + comm_cost.bwd += num_ele_in_comm['backward'] + comm_cost.total += num_ele_in_comm['total'] + + # check if communication action exists + # if so, loop over each action and compute the cost of each action + if strategy.communication_actions is not None: + for operand, comm_action in strategy.communication_actions.items(): + if isinstance(comm_action, CommAction): + comm_spec = comm_action.comm_spec + else: + # this condition branch will be removed after all the handler updated. + comm_spec = comm_action + if isinstance(comm_spec, dict): + src_spec = comm_spec['src_spec'] + tgt_spec = comm_spec['tgt_spec'] + shape_consistency_manager = ShapeConsistencyManager() + _, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec) + for comm_spec_ in comm_action_sequence: + _compute_and_add(operand, comm_spec_) + else: + _compute_and_add(operand, comm_spec) + + # update the communication cost attribute in-place + strategy.communication_cost = comm_cost + return strategy + + @abstractmethod + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + """ + Customize this method to compute the computation flops. + """ + pass + + @abstractmethod + def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + """ + Customize this method to compute the memory cost in bytes. + """ + pass + + def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str): + """ + Compute the size of a tensor in bytes. + + Args: + strategy (ShardingStrategy): the ShardingStrategy generated. + key (str): the name of the operation data defined by the generator. + """ + op_data = self.op_data[key] + + def _compute_size_in_bytes_helper(sharding_spec, meta_data): + sharded_shape = sharding_spec.get_sharded_shape_per_device() + if len(sharded_shape) == 0: + num_elements = 1 + else: + num_elements = reduce(operator.mul, sharded_shape) + dtype = getattr(meta_data, 'dtype') + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + return num_elements * size_per_elem_bytes + + if isinstance(op_data.data, tuple): + assert isinstance(strategy.sharding_specs[op_data], list), \ + 'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.' + total_bytes = 0 + for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]): + meta_data = op_data.data[index] + if isinstance(meta_data, torch.Tensor): + element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data) + else: + # if meta_data is not a tensor, we count the memroy as 0 + element_bytes = 0 + total_bytes += element_bytes + + else: + if isinstance(op_data.data, torch.Tensor): + total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data) + else: + # if op_data.data is not a tensor, we count the memroy as 0 + total_bytes = 0 + + return total_bytes + + def generate(self) -> List[ShardingStrategy]: + """ + Generate all possible sharding strategies for this operation. + """ + strategies = self.collate_strategies() + + # some strategies may be None as ignore_sharding_exception may return None + # when ShardingSpecException occurs. + # thus, remove those None values + strategies = [strategy for strategy in strategies if strategy] + + # update the costs + # update mete info on cost + # these update methods are all in-place, the default method will do nothing + # the cost info will only be added if the child class overrides these methods + for strategy in strategies: + self.update_communication_cost(strategy) + self.update_compute_cost(strategy) + self.update_memory_cost(strategy) + + return strategies + + @abstractmethod + def collate_strategies(self) -> List[ShardingStrategy]: + pass + + @abstractmethod + def validate(self) -> bool: + """ + Validate if the operands are of desired shape. + If True, means this generator can be used for the current operation. + """ + pass + + +class FollowingStrategyGenerator(StrategyGenerator): + """ + FollowingStrategyGenerator is used to generate the sharding strategies which depends on its predecessor node. + + TODO: remove the original strategy_generator.py after refactoring + """ + + def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, + predecessor_node: Node): + self.op_data = operation_data_mapping + self.device_mesh = device_mesh + self.predecessor_node = predecessor_node + + +class OutputStrategyGenerator(StrategyGenerator): + """ + OutputStrategyGenerator is used to generate the sharding strategies for Output Node. + """ + + def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, + predecessor_nodes: List[Node]): + super().__init__(operation_data_mapping, device_mesh) + self.predecessor_nodes = predecessor_nodes diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..a0fbc58d70c0feba2c78305fb14d9bcb38a82e41 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py @@ -0,0 +1,113 @@ +import copy +import operator +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + check_keep_sharding_status, + detect_reshape_mapping, + infer_output_dim_partition_dict, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.tensor.sharding_spec import ShardingSpec + +__all__ = ['SumGenerator'] + + +class SumGenerator(FollowingStrategyGenerator): + """ + SumGenerator deals with the sharding strategies of torch.sum op. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + input_size_product = reduce(operator.mul, sharded_input_shape) + output_size_product = reduce(operator.mul, sharded_output_shape) + + compute_cost = TrainCycleItem(fwd=input_size_product, + bwd=output_size_product, + total=input_size_product + output_size_product) + + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) + sum_dims, sum_mapping_dict = self.op_data['sum_info'].data + + # TODO: a better way to handle the distributed sum is sum all the data on chip and then do all reduce + # among all the shard groups + recover_dims = [] + dim_partition_dict_for_output = {} + for dim in dim_partition_dict_for_input: + if dim in sum_dims: + recover_dims.append(dim) + elif dim in sum_mapping_dict: + dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim] + else: + raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims') + + for dim in recover_dims: + dim_partition_dict_for_input.pop(dim) + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..93cfc9eeea532ac4383f0821008deeccb13951d0 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py @@ -0,0 +1,67 @@ +import copy +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.tensor.sharding_spec import ShardingSpec + +from .strategy_generator import StrategyGenerator + +__all__ = ['TensorConstructorGenerator'] + + +class TensorConstructorGenerator(StrategyGenerator): + """ + TensorConstructorGenerator which deals with + the sharding strategies for tensor constructor operation, such as torch.arange. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + bwd_mem_cost = MemoryCost(activation=0, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + dim_partition_dict_mapping = { + "output": {}, + } + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Replica Tensor Constructor' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..b867a30686eb97a55096895d344dcc28b51f347a --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py @@ -0,0 +1,77 @@ +import copy +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) + +from .strategy_generator import FollowingStrategyGenerator + +__all__ = ['UnaryElementwiseGenerator'] + + +class UnaryElementwiseGenerator(FollowingStrategyGenerator): + """ + UnaryElementwiseGenerator which deals with the sharding strategies of UnaryElementwiseOp. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + # For element-wise function, we keep the sharding spec of output node same as + # the input. Therefore, the different strategies of input node with same + # output sharding spec will generate same strategy for element-wise function. + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..fa941f2cc51dc4d817bfc8f49c54bbaf7a8a5407 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py @@ -0,0 +1,98 @@ +import copy +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception, +) + +from .strategy_generator import StrategyGenerator + +__all__ = ['WhereGenerator'] + + +class WhereGenerator(StrategyGenerator): + """ + WhereGenerator is a generic class to generate strategies for Where operation. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'condition': self._compute_size_in_bytes(strategy, "condition"), + 'x': self._compute_size_in_bytes(strategy, "x"), + 'y': self._compute_size_in_bytes(strategy, "y"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = condition + x + y + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0) + + # compute bwd cost incurred + # bwd_cost = condition_grad + x_grad + y_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items()]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, parameter=0) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def _generate_strategy_with_dim_partition(self, dim_partition): + dim_partition_dict_mapping = { + "condition": dim_partition, + "x": dim_partition, + "y": dim_partition, + "output": dim_partition + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}' + communication_action_mapping = {} + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + return strategy + + def enumerate_all_possible_output_spec(self, mesh_dim_0, mesh_dim_1, dimension_length): + dim_partition_list = [] + dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_0, dimension_length)) + dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_1, dimension_length)) + dim_partition_list.extend(enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dimension_length)) + # append {} for non_split case + dim_partition_list.append({}) + + return dim_partition_list + + def collate_strategies(self) -> List[ShardingStrategy]: + ''' + Generate every possible strategies for a where node, and record all strategies into the strategies_vector. + ''' + strategy_list = [] + + dimension_length = len(self.op_data["output"].logical_shape) + dim_partition_list = self.enumerate_all_possible_output_spec(0, 1, dimension_length) + for dim_partition in dim_partition_list: + strategy = self._generate_strategy_with_dim_partition(dim_partition) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..86f90694e0604f72e9564020ccab455cfdee29a0 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py @@ -0,0 +1,81 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator, SumGenerator + +__all__ = ['SumHandler'] + + +@operator_registry.register(torch.Tensor.sum) +@operator_registry.register(torch.sum) +class SumHandler(NodeHandler): + """ + A SumHandler which deals with the sharding strategies for torch.sum or torch.Tensor.sum. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(SumGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + if len(self.node.args) > 1: + sum_dims = self.node.args[1] + else: + sum_dims = tuple(range(self.node.args[0]._meta_data.dim())) + + if isinstance(sum_dims, int): + sum_dims = (sum_dims,) + + # recover negative value to positive + num_dims = self.node.args[0]._meta_data.dim() + for i in range(len(sum_dims)): + if sum_dims[i] < 0: + sum_dims[i] += num_dims + + # mapping the input dims to output dims + # For examples: + # input: torch.rand(2, 3, 4, 5) + # output: torch.sum(input, (0, 2)) + # sum_mapping_dict = {1: 0, 3: 1} + # sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input + # sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input + sum_mapping_dict = {} + if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']: + for i in range(num_dims): + sum_mapping_dict.update({i: i}) + else: + output_index = 0 + for i in range(num_dims): + if i not in sum_dims: + sum_mapping_dict.update({i: output_index}) + output_index += 1 + assert output_index == self.node._meta_data.dim() + + sum_info = (sum_dims, sum_mapping_dict) + physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "sum_info": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..855a2e7612af0cb59cae9bc8574197fad098f983 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py @@ -0,0 +1,32 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator +from .strategy.tensor_constructor_generator import TensorConstructorGenerator + +__all__ = ['TensorConstructorHandler'] + + +@operator_registry.register(torch.arange) +class TensorConstructorHandler(NodeHandler): + """ + A TensorConstructorHandler which deals with the sharding strategies for tensor constructor operations, such as torch.arange. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(TensorConstructorGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = {"output": physical_output_operand} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..7a9d377264905a650fd991cb10e98f8c3f16f871 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py @@ -0,0 +1,64 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator, TransposeGenerator + +__all__ = ['TransposeHandler'] + + +@operator_registry.register(torch.Tensor.transpose) +@operator_registry.register(torch.transpose) +class TransposeHandler(NodeHandler): + """ + A TransposeHandler which deals with the sharding strategies for torch.permute or torch.transpose. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(TransposeGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + transpose_dims = [] + # torch.transpose (input, dim0, dim1) + for arg in self.node.args: + if isinstance(arg, torch.fx.Node): + if isinstance(arg._meta_data, int): + transpose_dims.append(arg._meta_data) + else: + transpose_dims.append(arg) + + num_dims = self.node._meta_data.dim() + for i in range(2): + # recover negative value to positive + if transpose_dims[i] < 0: + transpose_dims[i] += num_dims + + physical_shape_operand = OperationData(name='transpose_dims', + type=OperationDataType.ARG, + data=list(transpose_dims)) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "transpose_dims": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..0362de780d7af0fa9569a575a122f84ffb42b0db --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py @@ -0,0 +1,43 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import MetaInfoNodeHandler, NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator, UnaryElementwiseGenerator + +__all__ = ['UnaryElementwiseHandler'] + + +@operator_registry.register(torch.Tensor.to) +@operator_registry.register(torch.Tensor.type) +@operator_registry.register(torch.abs) +@operator_registry.register(torch.nn.ReLU) +@operator_registry.register(torch.nn.Tanh) +@operator_registry.register(torch.tanh) +@operator_registry.register(torch.nn.modules.dropout.Dropout) +@operator_registry.register(torch.Tensor.contiguous) +@operator_registry.register(torch.nn.functional.dropout) +class UnaryElementwiseHandler(MetaInfoNodeHandler): + """ + A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(UnaryElementwiseGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..7dff89d1d7a39a6e4fe73514bfb16abe2e3e7bea --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py @@ -0,0 +1,52 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator, ViewGenerator + +__all__ = ['ViewHandler'] + + +@operator_registry.register(torch.Tensor.reshape) +@operator_registry.register(torch.reshape) +@operator_registry.register(torch.Tensor.view) +class ViewHandler(NodeHandler): + """ + A ViewHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(ViewGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + target_shape = self.node._meta_data.shape + physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "tgt_shape": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..6de2aaafdd018f08195563ef882f07eb39d8d20a --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py @@ -0,0 +1,71 @@ +import copy +import operator +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector +from ..utils import recover_sharding_spec_for_broadcast_shape +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator, WhereGenerator + +__all__ = ['WhereHandler'] + + +@operator_registry.register(torch.where) +class WhereHandler(NodeHandler): + """ + A WhereHandler which deals with the sharding strategies for torch.where. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + logical_op_data_mapping, _ = self.get_operation_data_mapping() + generators = [] + generators.append(WhereGenerator(logical_op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_condition_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_x_operand = OperationData(name=str(self.node.args[1]), + type=OperationDataType.ARG, + data=self.node.args[1]._meta_data) + physical_y_operand = OperationData(name=str(self.node.args[2]), + type=OperationDataType.ARG, + data=self.node.args[2]._meta_data) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + physical_mapping = { + "condition": physical_condition_operand, + "x": physical_x_operand, + "y": physical_y_operand, + "output": physical_output + } + logical_shape_for_all = self.node._meta_data.shape + logical_mapping = {} + for key, physical_operand in physical_mapping.items(): + logical_mapping[key] = self.convert_physical_operand_to_logical_operand(physical_operand, + logical_shape_for_all) + + return logical_mapping, physical_mapping + + def convert_physical_operand_to_logical_operand(self, physical_operand, target_shape): + logical_operand = copy.deepcopy(physical_operand) + logical_operand.logical_shape = target_shape + return logical_operand + + def post_process(self, strategy: ShardingStrategy): + logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping() + for key in logical_op_data_mapping.keys(): + logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]] + logical_shape = logical_op_data_mapping[key].logical_shape + physical_shape = physical_op_data_mapping[key].logical_shape + physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + logical_sharding_spec, logical_shape, physical_shape) + strategy.sharding_specs.pop(logical_op_data_mapping[key]) + strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec + strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}" + return strategy diff --git a/colossalai/auto_parallel/tensor_shard/options.py b/colossalai/auto_parallel/tensor_shard/options.py new file mode 100644 index 0000000000000000000000000000000000000000..f0ea502a6f0e2c4412ac333f9465aec6873e9791 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/options.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass +from enum import Enum + +__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption'] + + +class SolverPerference(Enum): + """ + This enum class is to define the solver preference. + """ + STANDARD = 0 + DP = 1 + TP = 2 + + +class ShardOption(Enum): + """ + This enum class is to define the shard level required in node strategies. + + Notes: + STANDARD: We do not add any extra shard requirements. + SHARD: We require the node to be shard using at least one device mesh axis. + SHARD_ONE_AXIS: We require the node to be shard using the last device mesh axis. + FULL_SHARD: We require the node to be shard using all device mesh axes. + TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis. + TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes. + """ + STANDARD = 0 + SHARD = 1 + SHARD_LAST_AXIS = 2 + FULL_SHARD = 3 + + +class DataloaderOption(Enum): + """ + This enum class is to define the dataloader option. + """ + REPLICATED = 0 + DISTRIBUTED = 1 + + +@dataclass +class SolverOptions: + """ + SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search. + """ + solver_perference: SolverPerference = SolverPerference.STANDARD + dataloader_option: DataloaderOption = DataloaderOption.REPLICATED + shard_option: ShardOption = ShardOption.STANDARD diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..6af92727243759bf2d0e0e1b8f472e1e59308ca3 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py @@ -0,0 +1,277 @@ +from copy import deepcopy +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Tuple, Union + +import torch +from torch.fx.node import Node + +from colossalai.tensor.comm_spec import CommSpec +from colossalai.tensor.sharding_spec import ShardingSpec + +from .constants import ( + BCAST_FUNC_OP, + ELEMENTWISE_FUNC_OP, + ELEMENTWISE_METHOD_OP, + ELEMENTWISE_MODULE_OP, + RESHAPE_FUNC_OP, + RESHAPE_METHOD_OP, +) + +__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector'] + + +class OperationDataType(Enum): + """ + An operation can come from the argument list of an operator or the parameter list of a module. + """ + INPUT = 0 + ARG = 1 + PARAM = 2 + BUFFER = 3 + OUTPUT = 4 + + +@dataclass +class OperationData: + """ + OperationData is the data related to an operator, the data can be the operand or the output. + + Args: + name (str): the name of the operation-related data + type (OperationDataType): the type of the operation data + data (Any): the value for this data, usually it is a meta tensor. + logical_shape (Tuple[int]): the logical shape of the data, it can be different from the its actual shape in memory. + """ + name: str + type: OperationDataType + data: Any + logical_shape: Tuple[int] = None + + def __post_init__(self): + # if no logical shape is specified, use the data shape as the logical shape + if self.logical_shape is None: + + def _infer_logical_shape(data: any): + """ + This function is used to infer the logical shape of the data. + """ + if isinstance(data, torch.Tensor): + return data.shape + elif isinstance(data, torch.Size): + return None + elif isinstance(data, (tuple, list)): + data_type = type(data) + return data_type([_infer_logical_shape(d) for d in data]) + else: + return None + + self.logical_shape = _infer_logical_shape(self.data) + + def __repr__(self) -> str: + return f'OperationData(name={self.name}, type={self.type})' + + def __eq__(self, other) -> bool: + return other.name == self.name + + def __hash__(self) -> int: + return hash(f'{self.name}') + + +@dataclass +class TrainCycleItem: + """ + TrainCycleItem is a dataclass to store the items which have different values for the forward and backward pass + in a training iteration. + + Args: + fwd (float): the item for the forward pass + bwd (float): the item for the backward pass + """ + fwd: Any + bwd: Any + total: Any + + +@dataclass +class MemoryCost: + """ + MemoryCost is a dataclass which stores the memory usage in the program. + + Args: + activation (int): the memory cost incurred by the activations in bytes. + parameter (int): the memory cost incurred by the module parameter in bytes. + temp (int): the memory cost incurred by the temporary tensors in bytes. + buffer (int): the memory cost incurred by the module buffer in bytes. + """ + activation: int = 0 + parameter: int = 0 + temp: int = 0 + buffer: int = 0 + + +class CommType(Enum): + """ + CommType describes the sequential order of a communication action and a computation action. + + Meaning: + BEFORE: the communication action happens just before the computation operation. + AFTER: the communication action happens after the computation operation. + HOOK: the communication action is used to do the grad all reduce. + IMPLICIT: the communication action happens during the kernel execution, such as SyncBatchNorm + """ + BEFORE = 0 + AFTER = 1 + HOOK = 2 + IMPLICIT = 3 + + +@dataclass +class CommAction: + """ + CommAction is used to record the communication action. + + Args: + comm_spec: express the communication pattern and the process groups to execute the communication action. + comm_type: describes the sequential order of a communication action and a computation action. + arg_index: record the location of tensor which join the communication, we cannot use name of node or op_data at runtime, + because the args of node may be changed by graph transform passes. + """ + comm_spec: CommSpec = None + comm_type: CommType = None + arg_index: int = -1 + key_for_kwarg: any = None + + +@dataclass +class ShardingStrategy: + """ + ShardingStrategy is a dataclass to store the meta information on tensor sharding for a node. + + Args: + name (str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'. + output_sharding_spec (ShardingSpec): ShardingSpec of the output node. + compute_cost (TrainCycleItem): Computation cost to complete this strategy. (default to None) + communication_cost (TrainCycleItem): Communication cost to complete this strategy. (default to None) + memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None) + input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes. + """ + name: str + sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None + compute_cost: TrainCycleItem = None + communication_cost: TrainCycleItem = None + memory_cost: TrainCycleItem = None + communication_actions: Dict[OperationData, CommAction] = None + resharding_costs: Dict[Node, List[TrainCycleItem]] = None + + @property + def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]: + specs = {} + specs.update(self._get_sharding_spec(OperationDataType.ARG)) + specs.update(self._get_sharding_spec(OperationDataType.PARAM)) + return specs + + @property + def argument_sharding_specs(self) -> Dict[OperationData, ShardingSpec]: + return self._get_sharding_spec(OperationDataType.ARG) + + @property + def param_sharding_specs(self) -> Dict[OperationData, ShardingSpec]: + return self._get_sharding_spec(OperationDataType.PARAM) + + @property + def output_sharding_specs(self) -> Dict[OperationData, ShardingSpec]: + return self._get_sharding_spec(OperationDataType.OUTPUT) + + def _get_sharding_spec(self, operation_data_type: OperationDataType): + specs = {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type} + return specs + + def get_op_data_by_name(self, name: str): + for op_data in self.sharding_specs.keys(): + if op_data.name == name: + return op_data + raise KeyError(f"Could not find the OperationData with name {name}") + + def get_sharding_spec_by_name(self, name: str): + for op_data, sharding_spec in self.sharding_specs.items(): + if op_data.name == name: + return sharding_spec + raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}") + + def clone(self): + + def _deepcopy_dict_vals(data: Dict): + return {k: deepcopy(v) for k, v in data.items()} + + sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs is not None else None + # We need to deepcopy it when self.communication_actions is not None, instead of checking its __bool__ value. + # Consider the examples below: + # If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False. + # In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items. + communication_actions = _deepcopy_dict_vals( + self.communication_actions) if self.communication_actions is not None else None + # same reason as communication_actions + resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None + compute_cost = deepcopy(self.compute_cost) + communication_cost = deepcopy(self.communication_cost) + memory_cost = deepcopy(self.memory_cost) + + return ShardingStrategy(name=self.name, + sharding_specs=sharding_specs, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + communication_actions=communication_actions, + resharding_costs=resharding_costs) + + +class StrategiesVector(list): + ''' + Each node in fx graph will have a corresponding StrategiesVector, to store all the possible + strategies of the node. + + Argument: + node (Node): node for which the list of sharding strategies are generated. + ''' + + def __init__(self, node: Node): + super().__init__() + self.node = node + # fetch its input and output nodes + # TODO: placeholder input nodes + self.predecessor_nodes = list(node._input_nodes.keys()) + self.successor_nodes = list(node.users.keys()) + + def check_merge(self): + merge_label = False + if self.node.op == 'call_module': + target = self.node.target + root_module = self.node.graph.owning_module + submod = root_module.get_submodule(target) + submod_type = type(submod) + # merge elementwise module node into source nodes + # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec. + if submod_type in ELEMENTWISE_MODULE_OP: + merge_label = True + + if self.node.op == 'call_function': + # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec. + if self.node.target in ELEMENTWISE_FUNC_OP: + merge_label = True + # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case. + # TODO: remove this after we support the fall back logic. + # if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1: + # merge_label = True + # we could merge reshape op, because their computation costs are negligible. + if self.node.target in RESHAPE_FUNC_OP: + merge_label = True + + if self.node.op == 'call_method': + # we could merge reshape op, because their computation costs are negligible. + method = getattr(self.node.args[0]._meta_data.__class__, self.node.target) + if method in RESHAPE_METHOD_OP: + merge_label = True + if method in ELEMENTWISE_METHOD_OP: + merge_label = True + return merge_label diff --git a/colossalai/auto_parallel/tensor_shard/solver/__init__.py b/colossalai/auto_parallel/tensor_shard/solver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9e6bd9239214c4def03b1b419a2845581fa083a --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/solver/__init__.py @@ -0,0 +1,6 @@ +from .cost_graph import CostGraph +from .graph_analysis import GraphAnalyser +from .solver import Solver +from .strategies_constructor import StrategiesConstructor + +__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph'] diff --git a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..74290453ca0c2dd40008ec51584a134c8f278410 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py @@ -0,0 +1,205 @@ +import torch + +from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST + + +class CostGraph: + ''' + A graph data structure to simplify the edge cost graph. It has two main functions: + 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in + CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list. + 2. To reduce the searching space, we merge computationally-trivial operators, such as + element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will + be given by the StrategiesVector depending on the type of target node and following nodes. + + Argument: + leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph. + simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True) + ''' + + def __init__(self, leaf_strategies, simplify=True, forward_only=False): + self.leaf_strategies = leaf_strategies + self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies] + # stores number of strategies in each node + self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies} + # extra_node_costs will store the extra costs introduced by merging nodes + self.extra_node_costs = {} + self.following_dict = {} + self.simplify = simplify + self.forward_only = forward_only + self._build_cost_graph() + + def _remove_invalid_node(self, node, attr_name): + remove_list = [] + target_node_list = getattr(node, attr_name, []) + for target_node in target_node_list: + if target_node not in self.nodes: + remove_list.append(target_node) + for element in remove_list: + target_node_list.remove(element) + + def _build_cost_graph(self): + ''' + This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be + set to node. + ''' + self.edge_costs = {} + if self.simplify: + self.merge_pair = [] + for strategies_vector in self.leaf_strategies: + # build edge_cost + dst_node = strategies_vector.node + for src_node in strategies_vector.predecessor_nodes: + if src_node not in self.nodes: + continue + node_pair = (src_node, dst_node) + edge_cost = {} + for i in range(len(strategies_vector)): + for j in range(len(src_node.strategies_vector)): + resharding_cost_item = strategies_vector[i].resharding_costs[src_node][j] + if self.forward_only: + edge_cost[(j, i)] = resharding_cost_item.fwd + else: + edge_cost[(j, i)] = resharding_cost_item.total + self.edge_costs[node_pair] = edge_cost + parent_nodes = [] + children_nodes = [] + + def _check_tensor_in_node(data): + """ + This method is used to check whether the data has a tensor inside or not. + """ + has_tensor_flag = False + if isinstance(data, torch.Tensor): + return True + elif isinstance(data, (tuple, list)): + for d in data: + has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d) + return has_tensor_flag + + for node in strategies_vector.predecessor_nodes: + if _check_tensor_in_node(node._meta_data): + parent_nodes.append(node) + for node in strategies_vector.successor_nodes: + if _check_tensor_in_node(node._meta_data): + children_nodes.append(node) + + setattr(dst_node, 'parents', parent_nodes) + setattr(dst_node, 'children', children_nodes) + + if self.simplify and strategies_vector.check_merge(): + for followed_node in strategies_vector.predecessor_nodes: + # we only merge node pairs which src node has a tensor element inside. + # This is necessay because the node without a tensor element inside will not + # be assigned any strategy. + if _check_tensor_in_node(followed_node._meta_data): + self.merge_pair.append((followed_node, dst_node)) + + def get_edge_cost(self, src_node, dst_node): + return self.edge_costs[(src_node, dst_node)] + + def merge_node(self, src_node, dst_node): + ''' + To merge dst_node into src_node, we need to do it in following steps: + + 1. For each strategy in dst_node, we need to pick an appropriate strategy + of src_node to merge, it is important because the logical resharding costs + between the parents node of src_node and merged node depend on the src_node + strategies dispatching. For example, for the graph 0->1->2, after merging node 1 + into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)] + x represents the picking strategy of node 1 merged into node 2 strategy 0. + + 2. We need to accumulate the extra costs introduced by merging nodes, the extra costs + contains two parts, one is resharding costs between src_node strategy and dst_node strategy, + another is the origin extra costs in src_node strategy. + + 3. Build connections between new node pairs, and remove the src_node after all consumer nodes + detached from it. + + Argument: + src_node(Node): The node will be merged into dst_node. + dst_node(Node): The node to integrate src_node. + ''' + # build merge_map + merge_map = {} + for src_index, _ in enumerate(src_node.strategies_vector): + min_cost = INFINITY_COST + lowest_cost_index = -1 + for dst_index, dst_strategy in enumerate(dst_node.strategies_vector): + resharding_cost_item = dst_strategy.resharding_costs[src_node][src_index] + if self.forward_only: + resharding_cost = resharding_cost_item.fwd + else: + resharding_cost = resharding_cost_item.total + if resharding_cost <= min_cost: + min_cost = resharding_cost + lowest_cost_index = dst_index + merge_map[src_index] = lowest_cost_index + + # extra_node_cost for src node + self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node] + for src_index, strategy in enumerate(src_node.strategies_vector): + target_strate_index = merge_map[src_index] + target_strategy = dst_node.strategies_vector[target_strate_index] + resharding_cost_item = target_strategy.resharding_costs[src_node][src_index] + if self.forward_only: + resharding_cost_to_add = resharding_cost_item.fwd + else: + resharding_cost_to_add = resharding_cost_item.total + self.extra_node_costs[src_node][src_index] += resharding_cost_to_add + if dst_node in self.extra_node_costs: + self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index] + + # add new node pair to cost graph + for child_node in dst_node.children: + new_node_pair = (src_node, child_node) + old_node_pair = (dst_node, child_node) + if new_node_pair in self.edge_costs: + continue + edge_cost = {} + for i in range(self.node_lens[src_node]): + for j in range(self.node_lens[child_node]): + dst_strate_index = merge_map[i] + edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)] + if new_node_pair not in self.edge_costs: + self.edge_costs[new_node_pair] = edge_cost + else: + # we should accumulate the resharding costs if args of child node contain + # both src node and dst node. + for index_pair, resharding_cost in self.edge_costs[new_node_pair]: + self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair] + + # connect src node and children of dst node + dst_node.parents.remove(src_node) + src_node.children.remove(dst_node) + self.edge_costs.pop((src_node, dst_node)) + for child_node in dst_node.children: + if child_node not in src_node.children: + src_node.children.append(child_node) + if src_node not in child_node.parents: + child_node.parents.append(src_node) + # remove dst node from cost graph when dst node has no producer. + if len(dst_node.parents) == 0: + child_node.parents.remove(dst_node) + node_pair = (dst_node, child_node) + self.edge_costs.pop(node_pair) + if len(dst_node.parents) == 0: + self.following_dict[dst_node] = src_node + dst_node.children = [] + + def _reindexing_src(self, src): + if src not in self.following_dict: + return src + return self._reindexing_src(self.following_dict[src]) + + def simplify_graph(self): + if not self.simplify: + return + self.merge_pair.reverse() + for (src_node, dst_node) in self.merge_pair: + self.merge_node(src_node, dst_node) + self.merge_pair.reverse() + reindexing_following_dict = {} + for dst, src in self.following_dict.items(): + reindexing_following_dict[dst] = self._reindexing_src(src) + self.following_dict = reindexing_following_dict diff --git a/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..be39a74cb23755f9ff2b83cf1123a7a7f9708ffa --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py @@ -0,0 +1,164 @@ +from dataclasses import dataclass +from typing import List + +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node + +from colossalai.fx.passes.utils import get_node_module + +__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser'] + + +@dataclass +class LiveVariable: + """ + LiveVariable is a data structure to store the meta information of a variable for liveness analysis. + """ + name: str + node: Node + is_inplace: bool + + +class LiveVariableVector(list): + """ + LiveVariableVector is a data structure to store the list of LiveVariable objects. + """ + + def exists(self, name) -> bool: + """ + Check if a variable has already existed in the current list by name. + """ + for var in self: + if name == var.name: + return True + return False + + def get(self, name) -> LiveVariable: + for var in self: + if name == var.name: + return var + raise KeyError(f"Variable {name} is not found") + + def copy(self) -> "LiveVariableVector": + """ + Create a copy of this vector + """ + vector = LiveVariableVector() + for var in self: + vector.append(var) + return vector + + +@dataclass +class LiveStage: + """ + LiveStage is a data structure to record the living variables at this current node. + """ + name: str + node: Node + all_live_vars: LiveVariableVector + unique_live_vars: LiveVariableVector + + +class GraphAnalyser: + + def __init__(self, gm: GraphModule): + self._gm = gm + self._graph = gm.graph + + @property + def gm(self) -> GraphModule: + """ + Return the GraphModule object associated with this analyser. + """ + return self._gm + + @property + def graph(self) -> Graph: + """ + Return the Graph object associated with this analyser. + """ + return self._graph + + def liveness_analysis(self) -> List[LiveStage]: + """ + Analyse the graph to obtain the variable liveness information. This function returns + an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object. + """ + compute_nodes = self.graph.nodes + liveness_list = [] + + # checked: record all variables created since the first stage + # all: record the live variables only exist until the current stage. + # this can be different from the `checked list`` as some varialbes may be destroyed prior to this stage. + # unique: record the unique live variables only exist until the current stage. + # this is different from `all list` as some variables are duplicated. + checked_variables = LiveVariableVector() + all_live_variables = LiveVariableVector() + unique_live_vars = LiveVariableVector() + + for idx, node in enumerate(compute_nodes): + ############################# + # find new living variables # + ############################# + # detect whether the current op is an in-place op + # if it is an in-place op, we would deem it as a duplciate var + is_inplace = False + if node.op == 'call_function': + # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True) + if node.kwargs.get('inplace', False): + is_inplace = True + elif node.op == 'call_module': + # to check if this is an inplace op such as torch.nn.Relu(inplace=True) + module = get_node_module(node) + if getattr(module, 'inplace', False): + is_inplace = True + + # add the output var + meta = getattr(node, '_meta_data', None) + live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace) + if not is_inplace: + unique_live_vars.append(live_var) + checked_variables.append(live_var) + all_live_variables.append(live_var) + + # check if any input is not checked yet + for arg in node.args: + if not isinstance(arg, Node): + continue + arg_name = arg.name + if not checked_variables.exists(arg_name): + live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False) + all_live_variables.append(live_var_from_arg) + checked_variables.append(live_var_from_arg) + unique_live_vars.append(live_var_from_arg) + + # TODO: add the logic to remove live variables + # this should be completed if we are able to trace the backward compute graph + + # add this stage to liveness dict + stage = LiveStage(name=node.name, + node=node, + all_live_vars=all_live_variables.copy(), + unique_live_vars=unique_live_vars.copy()) + # if a LiveStage is covered by another LiveStage, we just keep the larger one. + replace = False + for index, prev_stage in enumerate(liveness_list): + all_covered = True + for ele in prev_stage.unique_live_vars: + if ele not in stage.unique_live_vars: + all_covered = False + break + if all_covered: + replace = True + break + if replace: + liveness_list[index] = stage + else: + liveness_list.append(stage) + + return liveness_list + + def get_alias_set(self): + pass diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c6663dce80671199dcaf235380178622813310 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py @@ -0,0 +1,501 @@ +"""This code is adapted from Alpa + https://github.com/alpa-projects/alpa/ + with some changes. """ + +import multiprocessing +import time +import warnings +from typing import Dict + +import numpy as np +from torch.fx.graph import Graph +from torch.fx.node import Node + +from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST + +from .cost_graph import CostGraph +from .graph_analysis import GraphAnalyser +from .strategies_constructor import StrategiesConstructor + +try: + import pulp + from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum +except: + warnings.warn(f'please install the pulp') + +__all___ = ['Solver'] + + +class Solver: + + def __init__(self, + graph: Graph, + strategies_constructor: StrategiesConstructor, + cost_graph: CostGraph, + graph_analyser: GraphAnalyser = None, + memory_budget: float = -1.0, + solution_numbers: int = 1, + forward_only: bool = False, + memory_increasing_coefficient: float = 1.3, + verbose=False): + ''' + Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. + Argument: + graph: The computing graph to be optimized. + strategies_constructor: It will provide all the possible strategies for each node in the computing graph. + cost_graph: A graph data structure to simplify the edge cost graph. + graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints. + memory_budget: Memory constraint for the solution. + solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget. + memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget. + ''' + self.graph = graph + self.strategies_constructor = strategies_constructor + self.cost_graph = cost_graph + self.graph_analyser = graph_analyser + self.leaf_strategies = self.strategies_constructor.leaf_strategies + self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies] + self.strategy_map = self.strategies_constructor.strategy_map + self.memory_budget = memory_budget + self.solution_numbers = solution_numbers + self.forward_only = forward_only + if self.solution_numbers > 1: + self.memory_increasing_coefficient = memory_increasing_coefficient + else: + self.memory_increasing_coefficient = 1 + # temporarily we use all nodes as liveness list, we count the backward memory cost together with + # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase. + # self.liveness_list = self.graph_analyser.liveness_analysis() + self.liveness_list = self.nodes + self.node_index_dict = self._generate_node_index_dict() + # The last solution vector of auto sharding. + self.last_s_val = None + # The last objective value of the best ILP solution. + self.last_objective = None + self.verbose = verbose + + def _recover_merged_node_strategy(self): + ''' + During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node. + Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged + node. + ''' + for node_index, node in enumerate(self.nodes): + if node.strategies_vector.check_merge(): + # the merged node has only one input, and its strategies follow the input sharding strategy + input_strategies_vector = node.args[0].strategies_vector + input_best_strategy_index = self.last_s_val[node_index - 1] + input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec + for strategy_index, strategy in enumerate(node.strategies_vector): + if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence: + self.last_s_val[node_index] = strategy_index + break + + def _generate_node_index_dict(self) -> Dict[Node, int]: + node_index_dict = {} + for index, strategies_vector in enumerate(self.leaf_strategies): + node_index_dict[strategies_vector.node] = index + return node_index_dict + + def _prepare_data_for_solver(self): + ''' + Extract information from components for solver. + ''' + node_nums = len(self.leaf_strategies) + memory_budget = self.memory_budget + + # prepare strategies_len + strategies_len = [] + for node in self.nodes: + strategies_len.append(self.cost_graph.node_lens[node]) + strategies_len = np.array(strategies_len) + + # prepare following_nodes + following_nodes = self.cost_graph.following_dict + index_following_nodes = {} + for src, target in following_nodes.items(): + src_index = self.node_index_dict[src] + target_index = self.node_index_dict[target] + index_following_nodes[src_index] = target_index + following_nodes = index_following_nodes + for index in range(node_nums): + if index not in following_nodes: + following_nodes[index] = -1 + + # prepare edge_pairs and resharding costs + edge_pairs = [] + resharding_costs = [] + for pairs, edge_cost in self.cost_graph.edge_costs.items(): + src_node = pairs[0] + dst_node = pairs[1] + src_node_index = self.node_index_dict[src_node] + dst_node_index = self.node_index_dict[dst_node] + edge_pairs.append(src_node_index) + edge_pairs.append(dst_node_index) + + for i in range(strategies_len[src_node_index]): + for j in range(strategies_len[dst_node_index]): + resharding_costs.append(edge_cost[(i, j)]) + edge_pairs = np.array(edge_pairs) + resharding_costs = np.array(resharding_costs) + + # prepare liveness_set + liveness_set = self.liveness_list + + # omit alias_set now + alias_set = self.strategies_constructor.alias_set + alias_convert_costs = None + + # prepare compute_costs, communication_costs and memory_costs + compute_costs = [] + communication_costs = [] + memory_costs = [] + extra_node_costs = self.cost_graph.extra_node_costs + for strategies_vector in self.leaf_strategies: + node = strategies_vector.node + for index, strategy in enumerate(strategies_vector): + compute_cost_item = strategy.compute_cost + communication_cost_item = strategy.communication_cost + memory_cost_item = strategy.memory_cost + + if self.forward_only: + origin_communication_cost = communication_cost_item.fwd + compute_cost = compute_cost_item.fwd + # extract MemoryCost item from the memory TrainCycleItem + memory_cost = memory_cost_item.fwd + else: + origin_communication_cost = communication_cost_item.total + compute_cost = compute_cost_item.total + # extract MemoryCost item from the memory TrainCycleItem + memory_cost = memory_cost_item.total + + # extract the memory cost in float from MemoryCost item and sum them up + memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer + compute_costs.append(compute_cost) + # node in extra_node_costs means it has some extra communication + # cost from node merging, so we need to add those extra communication + # cost into + if node in extra_node_costs: + extra_node_cost = extra_node_costs[node][index] + communication_cost = origin_communication_cost + extra_node_cost + communication_costs.append(communication_cost) + else: + communication_costs.append(origin_communication_cost) + memory_costs.append(memory_cost) + + compute_costs = np.array(compute_costs) + communication_costs = np.array(communication_costs) + memory_costs = np.array(memory_costs) + + # omit initial value for nodes + s_init_np = None + + return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose + + def _call_solver_serialized_args(self, + node_nums, + memory_budget, + strategies_len, + following_nodes, + edge_pairs, + alias_set, + liveness_set, + compute_costs, + communication_costs, + memory_costs, + resharding_costs, + alias_convert_costs, + s_init_np=None, + verbose=True): + """ + Call the solver with serialized arguments. + """ + + tic = time.time() + + for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]: + assert isinstance(x, np.ndarray) + assert len(strategies_len) == node_nums, "strategies_len" + + def get_non_zero_index(binary_vector): + """ + Get the index of non-zero item in a vector. + """ + ct = 0 + ret = None + for i, elem in enumerate(binary_vector): + if pulp.value(elem): + ret = i + ct += 1 + + assert ct == 1 + return ret + + # 0. Unpack flatten numpy arrays + s_follow = following_nodes + s_alias = alias_set + + E = edge_pairs.reshape((-1, 2)) # noqa + r = [] + pt = 0 + edge_set = set() + for (i, j) in E: + prod_length = strategies_len[i] * strategies_len[j] + + if (i, j) in edge_set: + raise ValueError(f"Duplicated edges: {(i, j)}") + + edge_set.add((i, j)) + r.append(resharding_costs[pt:pt + prod_length]) + pt += prod_length + assert pt == len(resharding_costs) + + ###################### + # omit alias set now # + ###################### + + # A = alias_set.reshape((-1, 2)) # noqa + # for (i, j) in A: + # prod_length = strategies_len[i] * strategies_len[j] + # v.append(alias_convert_costs[pt:pt + prod_length]) + # pt += prod_length + # assert pt == len(alias_convert_costs) + + # L = [] # noqa + # pt = node_nums + # for i in range(node_nums): + # length = liveness_set[i] + # L.append(liveness_set[pt:pt + length]) + # pt += length + # assert pt == len(liveness_set) + v = [] + pt = 0 + + c = [] + d = [] + m = [] + pt = 0 + for i in range(node_nums): + length = strategies_len[i] + c.append(compute_costs[pt:pt + length]) + d.append(communication_costs[pt:pt + length]) + m.append(memory_costs[pt:pt + length]) + pt += length + assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}" + assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}" + assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}" + + # 1. Create variables + + ############################# + # create variables for node # + ############################# + s = [] + num_nodes = 0 + reverse_follow_backpatch = [] + for i in range(node_nums): + if s_follow[i] < 0: + if strategies_len[i] == 1: + s.append([1]) + else: + if i not in s_alias: + num_nodes += 1 + s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary")) + else: + s.append(s[s_alias[i]]) + else: + if s_follow[i] < len(s): + s.append(s[s_follow[i]]) + else: + s.append(None) + reverse_follow_backpatch.append(i) + + for i in reverse_follow_backpatch: + s[i] = s[s_follow[i]] + + ############################# + # create variables for edge # + ############################# + e = [] + num_edges = 0 + map_edge_to_idx = {} + for (idx, (i, j)) in enumerate(E): + if len(s[i]) == 1: + e.append(s[j]) + elif len(s[j]) == 1: + e.append(s[i]) + else: + if i in s_alias and j in s_alias and (s_alias[i], s_alias[j]) in map_edge_to_idx: + e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]]) + else: + num_edges += 1 + e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary")) + assert len(e[idx]) == len(r[idx]) + map_edge_to_idx[(i, j)] = idx + for element in s: + assert len(element) > 0 + # 2. Set initial value + ###################################### + # set a initial value for warm start # + ###################################### + if s_init_np is not None: + s_init = s_init_np.reshape((-1, 3)) + for (idx, value, fix) in s_init: + for i in range(len(s[idx])): + s[idx][i].setInitialValue(i == value) + if fix: + s[idx][i].fixValue() + + # 3. Objective + prob = LpProblem("myProblem", LpMinimize) + ################################################################### + # computing the node cost(computing cost and communication cost) # + ################################################################### + obj = 0 + for i in range(node_nums): + assert len(s[i]) == len(c[i]) + assert len(s[i]) == len(d[i]) + + obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i]) + + ############################################# + # computing the edge cost(resharding cost) # + ############################################# + for i in range(len(E)): + assert len(e[i]) == len(r[i]) + obj += lpDot(e[i], r[i]) + + prob += obj + + # 4. Constraints + # (a). specified by `cat="Binary"` + + # (b) + ################################################# + # make sure each node only choose one strategy # + ################################################# + for i in range(node_nums): + if s_follow[i] < 0: + prob += lpSum(s[i]) == 1 + + # (c) + ################################################# + # compute memory consumption with liveness set # + ################################################# + if memory_budget > 0: + mem = 0 + for node in liveness_set: + if node not in self.node_index_dict: + continue + node_index = self.node_index_dict[node] + mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index]))) + prob += mem <= memory_budget + + # (d). specified by `cat="Binary"` + + for (idx, (i, j)) in enumerate(E): + if strategies_len[i] == 1 or strategies_len[j] == 1: + continue + + # (e) + prob += lpSum(e[idx]) == 1 + + # (f) + for row in range(len(s[i])): + C = len(s[j]) # noqa + prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row] + + # (g) + for col in range(len(s[j])): + R = len(s[i]) # noqa + C = len(s[j]) # noqa + prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col] + + # (h) + ###################### + # omit alias set now # + ###################### + + # alias_set = set() + # for (idx, (i, j)) in enumerate(A): + # R = len(s[i]) # noqa + # C = len(s[j]) # noqa + # if (i, j) in alias_set: + # raise ValueError(f"Duplicated edges: {(i, j)}") + + # alias_set.add((i, j)) + # alias_set.add((j, i)) + + # for row in range(len(s[i])): + # for col in range(len(s[j])): + # if v[idx][row * C + col] > 0.5: + # prob += s[i][row] + s[j][col] <= 1 + + msg = verbose + time_limit = 600 + assert "COIN_CMD" in pulp.listSolvers( + onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'") + + solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count()) + # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit) + prob.solve(solver) + + status = prob.status + objective = pulp.value(prob.objective) + objective = float(objective) if objective is not None else -1.0 + if verbose: + print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" + f"Time: {time.time() - tic}") + print(f"#nodes: {num_nodes}, #edges: {num_edges}") + + if prob.status in [pulp.LpStatusInfeasible]: + raise RuntimeError("Cannot run the function under the given memory budget. " + "Please increase the memory budget.") + + # Get and check results + s_val = np.full((node_nums,), -1, dtype=np.int32) + for i in range(node_nums): + s_val[i] = get_non_zero_index(s[i]) + + e_val = np.full((len(E),), -1, dtype=np.int32) + for (idx, (i, j)) in enumerate(E): + e_val[idx] = get_non_zero_index(e[idx]) + i_spec_index = e_val[idx] // len(s[j]) + j_spec_index = e_val[idx] % len(s[j]) + assert i_spec_index == s_val[i], f"e_val[{i}][{j}]" + assert j_spec_index == s_val[j], f"e_val[{i}][{j}]" + if verbose and r[idx][e_val[idx]] > 0: + print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}") + + self.last_s_val = list(s_val) + # self._recover_merged_node_strategy() + self.last_objective = objective + + if objective > INFINITY_COST: + warnings.warn("Detect unexpected behaviors in the auto-sharding pass.") + + return self.last_s_val, e_val, self.last_objective, status + + def call_solver_serialized_args(self): + """ + Call the solver with serialized arguments and handle python errors. Additionally, + we could give a serious of solutions with different memory budget. + """ + if self.solution_numbers == 1: + args = self._prepare_data_for_solver() + ret = self._call_solver_serialized_args(*args) + + return ret + + origin_memory_budget = self.memory_budget + memory_budget_list = [ + origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers) + ] + ret_list = [] + for memory_budget in memory_budget_list: + self.memory_budget = memory_budget + args = self._prepare_data_for_solver() + ret = self._call_solver_serialized_args(*args) + ret_list.append(ret) + + return ret_list diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..044a8ac847ead4b6b7d9f05c3d19a43a8fc2346c --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py @@ -0,0 +1,198 @@ +import builtins +import math +import operator +from copy import deepcopy +from typing import Dict, List + +import torch +from torch.fx import Graph, Node + +from colossalai.auto_parallel.tensor_shard.node_handler import ( + GetattrHandler, + OutputHandler, + PlaceholderHandler, + operator_registry, +) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector +from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec +from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks +from colossalai.device.device_mesh import DeviceMesh + +from ..options import DataloaderOption, SolverOptions + +__all__ = ['StrategiesConstructor'] + + +class StrategiesConstructor: + """ + StrategiesConstructor is used to construct the parallelization plan for the model execution. + + Args: + graph (Graph): a Graph object used for analysis and strategy generation. + device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster. + solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching. + """ + + def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions): + self.graph = graph + assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' + self.root_module = self.graph.owning_module + self.nodes = list(graph.nodes) + self.device_mesh = device_mesh + self.leaf_strategies = [] + self.strategy_map = {} + self.solver_options = solver_options + self.no_strategy_nodes = [] + self.alias_set = None + + def remove_duplicated_strategy(self, strategies_vector): + ''' + In build_strategies_and_cost method, we may produce some duplicated strategies. + In this method, we will remove the duplicated strategies depending on the strategies name. + Note that this operation is in-place. + ''' + name_checklist = [] + remove_list = [] + for strategy in strategies_vector: + if strategy.name not in name_checklist: + name_checklist.append(strategy.name) + else: + remove_list.append(strategy) + for strategy in remove_list: + strategies_vector.remove(strategy) + + def generate_alias_set(self): + + node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies] + common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10) + + repeat_block_nums = len(common_blocks) + alias_set = {} + + if repeat_block_nums == 0: + return alias_set + + for index, common_node in enumerate(common_blocks[0]): + for i in range(1, repeat_block_nums): + alias_set[node_list.index(common_blocks[i][index])] = node_list.index(common_node) + return alias_set + + def build_strategies_and_cost(self): + """ + This method is to build the strategy vector for each node in the computation graph. + """ + + def _check_no_strategy_for_node(node): + if node.op in ('placeholder', 'get_attr', 'output'): + return False + + def _check_no_strategy_for_data(data): + label = True + if isinstance(data, torch.Tensor): + return False + elif isinstance(data, (tuple, list)): + for d in data: + label = label and _check_no_strategy_for_data(d) + return label + + return _check_no_strategy_for_data(node._meta_data) + + for node in self.nodes: + strategies_vector = StrategiesVector(node) + + if _check_no_strategy_for_node(node): + self.no_strategy_nodes.append(node) + pass + + # placeholder node + elif node.op == 'placeholder': + if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: + placeholder_option = 'distributed' + else: + assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' + placeholder_option = 'replicated' + placeholder_handler = PlaceholderHandler(node, + self.device_mesh, + strategies_vector, + placeholder_option=placeholder_option) + placeholder_handler.register_strategy() + + # get_attr node + elif node.op == 'get_attr': + getattr_handler = GetattrHandler(node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference) + getattr_handler.register_strategy() + + # call_module node + elif node.op == 'call_module': + target = node.target + submod = self.root_module.get_submodule(target) + submod_type = type(submod) + handler = operator_registry.get(submod_type)(node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference) + handler.register_strategy() + # attach strategies_info to node + if hasattr(handler, 'strategies_info'): + setattr(node, 'strategies_info', handler.strategies_info) + + # call_function node + elif node.op == 'call_function': + target = node.target + handler = operator_registry.get(target)(node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference) + handler.register_strategy() + # attach strategies_info to node + if hasattr(handler, 'strategies_info'): + setattr(node, 'strategies_info', handler.strategies_info) + + # call_method node + elif node.op == 'call_method': + method = getattr(node.args[0]._meta_data.__class__, node.target) + handler = operator_registry.get(method)(node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference) + handler.register_strategy() + # attach strategies_info to node + if hasattr(handler, 'strategies_info'): + setattr(node, 'strategies_info', handler.strategies_info) + + # output node + elif node.op == 'output': + if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: + output_option = 'distributed' + else: + assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' + output_option = 'replicated' + output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option) + output_handler.register_strategy() + + self.remove_duplicated_strategy(strategies_vector) + setattr(node, 'strategies_vector', strategies_vector) + self.leaf_strategies.append(strategies_vector) + self.strategy_map[node] = strategies_vector + + # remove no strategy nodes + remove_list = [] + for strategies_vector in self.leaf_strategies: + if len(strategies_vector) == 0: + remove_list.append(strategies_vector.node) + + for node in remove_list: + if node.strategies_vector in self.leaf_strategies: + self.leaf_strategies.remove(node.strategies_vector) + if node in self.strategy_map: + self.strategy_map.pop(node) + + alias_set = self.generate_alias_set() + self.alias_set = alias_set diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7fe5430bf136b08d93706b33cf4dbf82e342013 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py @@ -0,0 +1,25 @@ +from .broadcast import ( + BroadcastType, + comm_actions_for_oprands, + get_broadcast_shape, + is_broadcastable, + recover_sharding_spec_for_broadcast_shape, +) +from .factory import generate_resharding_costs, generate_sharding_spec +from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map +from .reshape import check_keep_sharding_status, detect_reshape_mapping, infer_output_dim_partition_dict +from .sharding import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + generate_sharding_size, + transpose_partition_dim, + update_partition_dim, +) + +__all__ = [ + 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape', + 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity' + 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', + 'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map', + 'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict' +] diff --git a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py new file mode 100644 index 0000000000000000000000000000000000000000..28aa551328d7a6d5f283338fe55c90eb102d253c --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py @@ -0,0 +1,160 @@ +from enum import Enum, auto +from typing import List + +import torch +from torch.fx.node import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + OperationData, + OperationDataType, +) +from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec +from colossalai.tensor.sharding_spec import ShardingSpec + +__all__ = [ + 'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape', + 'comm_actions_for_oprands' +] + + +class BroadcastType(Enum): + EQUAL = auto() + PADDDING = auto() + MULTIPLE = auto() + + +def is_broadcastable(shape1: torch.Size, shape2: torch.Size) -> bool: + """ + Check if two shapes are broadcastable to each other. + """ + for s1, s2 in zip(shape1[::-1], shape2[::-1]): + if s1 == 1 or s2 == 1 or s1 == s2: + pass + else: + return False + return True + + +def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]: + """ + Compute the broadcast shape given two shapes. + """ + assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable' + shape1_reverse = shape1[::-1] + shape2_reverse = shape2[::-1] + min_common_dim = min(len(shape1), len(shape2)) + dims = [] + for s1, s2 in zip(shape1_reverse, shape2_reverse): + dims.append(max(s1, s2)) + + # append the remaining dims + dims.extend(shape1_reverse[min_common_dim:]) + dims.extend(shape2_reverse[min_common_dim:]) + return dims[::-1] + + +def get_broadcast_dim_info(logical_shape, physical_shape): + # get the number of dimensions + logical_num_dims = len(logical_shape) + physical_num_dims = len(physical_shape) + + assert logical_num_dims >= physical_num_dims, \ + 'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!' + + # track the dim and its broadcasting type + logical_dim_broadcast_info = {} + + for i in range(logical_num_dims): + # get the trailing dim size + logical_dim_idx = logical_num_dims - i - 1 + phyiscal_dim_idx = physical_num_dims - i - 1 + logical_dim_size = logical_shape[logical_dim_idx] + + if phyiscal_dim_idx >= 0: + physical_dim_size = physical_shape[phyiscal_dim_idx] + + if physical_dim_size == logical_dim_size: + logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.EQUAL + elif physical_dim_size == 1 and physical_dim_size != logical_dim_size: + logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.MULTIPLE + else: + logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING + + return logical_dim_broadcast_info + + +def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, + physical_shape: torch.Size) -> ShardingSpec: + """ + This function computes the sharding spec for the physical shape of a broadcast tensor. + + Args: + logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor + logical_shape (torch.Size): logical shape is the broadcast shape of a tensor + physical_shape (torch.Size): the shape of the tensor before broadcasting + """ + # if the two shapes are the same, no broadcast occurs + # we directly return the current sharding spec + + # recording the sharding dimensions removed during logical shape converting to physical one + removed_dims = [] + if list(logical_shape) == list(physical_shape): + return logical_sharding_spec, removed_dims + + # get the number of dimensions + logical_num_dims = len(logical_shape) + physical_num_dims = len(physical_shape) + + # get the broadcast info + logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape) + + # generate the sharding spec for the physical shape + physical_dim_partition = {} + logical_dim_partition = logical_sharding_spec.dim_partition_dict + + for shape_dim, mesh_dim in logical_dim_partition.items(): + logical_broadcast_type = logical_dim_broadcast_info[shape_dim] + + if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE: + removed_dims.extend(mesh_dim) + else: + # get the corresponding physical dim + physical_dim = physical_num_dims - (logical_num_dims - shape_dim) + physical_dim_partition[physical_dim] = mesh_dim + + physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh, + entire_shape=physical_shape, + dim_partition_dict=physical_dim_partition) + + return physical_sharding_spec, removed_dims + + +def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData, + sharding_spec: ShardingSpec) -> CommAction: + """ + This method is used to generate communication actions for oprands which lose information + during convert logical shape to physical shape. + """ + if len(removed_dims) == 1: + # if list length is 1, extract element from list to avoid using flatten device mesh + removed_dims = removed_dims[0] + comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + sharding_spec=sharding_spec, + logical_process_axis=removed_dims) + if op_data.type == OperationDataType.PARAM: + comm_type = CommType.HOOK + else: + comm_type = CommType.BEFORE + arg_index = -1 + for index, arg in enumerate(node.args): + if op_data.name == str(arg): + arg_index = index + assert arg_index >= 0, f'op_data should be an argument of node.' + comm_action = CommAction( + comm_spec=comm_spec, + comm_type=comm_type, + arg_index=arg_index, + ) + return comm_action diff --git a/colossalai/auto_parallel/tensor_shard/utils/factory.py b/colossalai/auto_parallel/tensor_shard/utils/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..05331e56000110a982cc776a24eb81d45fceb825 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/factory.py @@ -0,0 +1,206 @@ +import copy +import operator +import warnings +from functools import reduce +from typing import Dict, List, Optional, Union + +import torch +from torch.fx.node import Node +from torch.utils._pytree import tree_map + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +from ..constants import INFINITY_COST + +__all__ = ['generate_sharding_spec', 'generate_resharding_costs'] + + +def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, + dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: + """ + Generate the sharding spec of the tensor based on the given dim_partition_dict. + + + Args: + input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node. + device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster. + dim_partition_dict (Dict[int, List[int]]): a dictionary to specify the sharding specs, the key is the tensor dimension and the value is the mesh dimension for sharding. + """ + + if isinstance(input_, Node): + assert hasattr(input_, '_meta_data'), f'The given node has no attribte _meta_data' + meta_tensor = input_._meta_data + assert meta_tensor is not None, "The given node's _meta_data attribute is None" + shape = meta_tensor.shape + elif isinstance(input_, torch.Tensor): + shape = input_.shape + else: + raise TypeError( + f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.' + ) + for dim_index, sharding_index_list in dim_partition_dict.items(): + sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list] + sharding_size = reduce(operator.mul, sharding_list, 1) + assert shape[ + dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.' + + sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict) + return sharding_spec + + +def generate_resharding_costs(nodes: List[Node], + sharding_specs: List[ShardingSpec], + count_backward: Optional[bool] = True, + dtype: Optional[torch.dtype] = None, + index=None): + ''' + Compute the resharding costs with this specific strategy. + + Argument: + nodes (List[Node]): a list of nodes + sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes. + count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference. + dtype (Optional[torch.dtype]): the data type for cost calculation, default is None. + ''' + # The resharding_cost of weight is counted due to sharing weight cases. + resharding_costs = {} + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + + # shape consistency manager is a singleton class + shape_consistency_manager = ShapeConsistencyManager() + + for input_node, input_spec in zip(nodes, sharding_specs): + resharding_costs[input_node] = [] + for strategy in input_node.strategies_vector: + input_sharding_spec = strategy.output_sharding_spec + if not isinstance(input_sharding_spec, ShardingSpec): + assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.' + input_sharding_spec = input_sharding_spec[index] + assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' + try: + # compute the resharding cost + _, _, total_resharding_cost = shape_consistency_manager.shape_consistency( + input_sharding_spec, input_spec) + + # we need multiply the size of elem dtype to get correct communication cost + resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes + except AssertionError as e: + warnings.warn(f'{e}') + resharding_cost = INFINITY_COST + resharding_costs[input_node].append(resharding_cost) + return resharding_costs + + +def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_length_threshold: int = 20): + ''' + Find the largest repeat blocks in the graph, whose length is larger than the threshold. + + Args: + gm (GraphModule): the graph module to be analyzed. + common_length_threshold (int): the threshold of the repeat block length. + ''' + + # graph = gm.graph + + def _process_args(args): + new_args = [] + for arg in args: + if hasattr(arg, '_meta_data'): + meta_data = arg._meta_data + else: + meta_data = arg + + def _process_arg(data): + if isinstance(data, torch.Tensor): + data = data.size() + elif isinstance(data, slice): + data = (data.start, data.step, data.stop) + return data + + new_meta_data = tree_map(_process_arg, meta_data) + new_args.append(new_meta_data) + + return new_args + + def _all_equal(check_list, check_fn): + base_value = check_list[-1] + for e in check_list: + if not check_fn(e, base_value): + return False + return True + + def _check_node_list_equal(l1, l2): + if len(l1) != len(l2): + return False + for node1, node2 in zip(l1, l2): + if hash(node1.hash_key) != hash(node2.hash_key): + return False + return True + + def _check_node_equal(node1, node2): + if hash(node1.hash_key) == hash(node2.hash_key): + return True + return False + + for index, node in enumerate(node_list): + if node.op == 'call_module': + target = node.target + submod = root_module.get_submodule(target) + submod_type = type(submod) + target = submod_type + else: + target = node.target + + new_args = _process_args(node.args) + + if node.op != 'get_attr': + hash_key = (node.op, target, *new_args) + else: + hash_key = (node.op,) + + setattr(node, 'hash_key', hash_key) + + hash_value_to_node_dict = {} + + for index, node in enumerate(node_list): + hash_value = hash(node.hash_key) + if hash_value not in hash_value_to_node_dict: + hash_value_to_node_dict[hash_value] = [] + hash_value_to_node_dict[hash_value].append(index) + + # node_list = list(graph.nodes) + + node_list_start = 0 + max_common_length = common_length_threshold + common_blocks_index = [] + for index, node in enumerate(node_list): + # the comparison will be triggered if a common node appears + if len(hash_value_to_node_dict[hash(node.hash_key)]) >= 2: + start_index_list = hash_value_to_node_dict[hash(node.hash_key)] + check_block_list = [node_list[start:start + max_common_length] for start in start_index_list] + + common_label = True + if not _all_equal(check_block_list, _check_node_list_equal): + common_label = False + + if common_label: + common_blocks_index = copy.deepcopy(start_index_list) + max_step = len(node_list) - common_blocks_index[-1] - max_common_length - 1 + + for i in range(max_step): + # add assertion to avoid out of index + next_node_list = [node_list[index + max_common_length + i] for index in start_index_list] + if not _all_equal(next_node_list, _check_node_equal): + max_step = i + break + max_common_length += max_step + node_list_start += max_common_length + + # recover common subgraph from the index + common_blocks = [] + for start in common_blocks_index: + common_blocks.append(node_list[start:start + max_common_length]) + + return common_blocks diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..9e402dab757820c5d76ee6d1166de473c040784b --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py @@ -0,0 +1,97 @@ +import functools +from typing import Any, Callable, Dict, List, Tuple, Type, Union + +import torch + +from colossalai.logging import get_dist_logger +from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException + +__all__ = ['ignore_sharding_exception', 'pytree_map'] + + +def ignore_sharding_exception(func): + """ + A function wrapper to handle the ShardingSpecException in the function. + If ShardingSpecException occurs, this function will return None. + + Usage: + # mute the assertion error in the function + @ignore_sharding_exception + def do_something(): + ... + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + logger = get_dist_logger() + rst = func(*args, **kwargs) + return rst + except ShardingSpecException as e: + logger.debug(e) + return None + + return wrapper + + +def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor): + """ + This function checks whether the ShardingSpec is valid for the physical tensor. + This check includes 3 items: + 1. the sharding spec covers all dimensions of the physical tensor + 2. the sharding spec for each dimension is divisible by the number of devices. + 3. the sharding spec's entire shape must match the tensor shape + # + """ + # make sure all dims are covered in sharding spec + sharding_len = len(sharding_spec.sharding_sequence) + tensor_num_dim = tensor.dim() + num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0] + num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1] + assert sharding_len == tensor_num_dim, \ + f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).' + + # make sure the sharding is valid for each dim + for i in range(tensor_num_dim): + dim_size = tensor.shape[i] + dim_spec = sharding_spec.sharding_sequence[i] + + if str(dim_spec).startswith('S'): + devices_str = str(dim_spec).lstrip('S') + num_devices = 1 + + if '0' in devices_str: + num_devices *= num_devices_in_col + if '1' in devices_str: + num_devices *= num_devices_in_row + + assert dim_size >= num_devices and dim_size % num_devices == 0, \ + f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.' + + # make sure the entire shape matches the physical tensor shape + assert sharding_spec.entire_shape == tensor.shape, \ + f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}' + + +def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: + """process object recursively, like pytree + + Args: + obj (:class:`Any`): object to process + fn (:class:`Callable`): a function to process subobject in obj + process_types (:class: `type | tuple[type]`): types to determine the type to process + map_all (:class: `bool`): if map_all is True, then any type of element will use fn + + Returns: + :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn` + """ + if isinstance(obj, dict): + return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj} + elif isinstance(obj, tuple): + return tuple(pytree_map(o, fn, process_types, map_all) for o in obj) + elif isinstance(obj, list): + return list(pytree_map(o, fn, process_types, map_all) for o in obj) + elif isinstance(obj, process_types): + return fn(obj) + else: + return fn(obj) if map_all else obj diff --git a/colossalai/auto_parallel/tensor_shard/utils/reshape.py b/colossalai/auto_parallel/tensor_shard/utils/reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..a32a14bf7d577713ae2cb986ffbb42d87b0cabc1 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/reshape.py @@ -0,0 +1,192 @@ +from enum import Enum +from typing import Dict, List, Tuple + +import torch + + +class PreviousStatus(Enum): + """ + This class shows the status of previous comparision. + """ + RESET = 0 + # ORIGIN means the dimension size of original tensor is larger in the previous comparision. + ORIGIN = 1 + # TGT means the dimension size of target tensor is larger in the previous comparision. + TGT = 2 + + +def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> Dict[Tuple[int], Tuple[int]]: + """ + This method is used to detect the reshape mapping between original tensor and target tensor. + + Returns: + reshape_mapping_dict: The dictionary shows how a tuple of origin dims(keys) mapping to the related + target dims(values) during reshaping operation. + Examples: + import torch + origin_shape = torch.Size([4, 4, 4]) + tgt_shape = torch.Size([2, 8, 2, 2]) + reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) + print(reshape_mapping_dict) + Output: + {(2,): (3, 2), (1, 0): (1,), (0,): (0, 1)} + """ + + # reverse the shape object + origin_shape = list(origin_shape) + tgt_shape = list(tgt_shape) + origin_shape.reverse() + tgt_shape.reverse() + + # initialize arguments + reshape_mapping_dict = {} + origin_len = len(origin_shape) + tgt_len = len(tgt_shape) + origin_index = 0 + tgt_index = 0 + original_dimension_size = origin_shape[origin_index] + tgt_dimension_size = tgt_shape[tgt_index] + tgt_dims = [tgt_len - tgt_index - 1] + origin_dims = [origin_len - origin_index - 1] + previous_label = PreviousStatus.RESET + + while origin_index != len(origin_shape) or tgt_index != len(tgt_shape): + if original_dimension_size == tgt_dimension_size: + reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) + # if the origin_dims has no element, it means the original tensor has been fully matched. + # Therefore, we do not have to increase the origin_index for that case. + if len(origin_dims) > 0: + origin_index += 1 + # if the tgt_dims has no element, it means the original tensor has been fully matched. + # Therefore, we do not have to increase the tgt_index for that case. + if len(tgt_dims) > 0: + tgt_index += 1 + # the last step of loop should always end with condition + # so we need to manually skip the preparation for next step + # in the last step. + if origin_index == len(origin_shape) and tgt_index == len(tgt_shape): + continue + + # If origin_index equals to origin_len, we just need to set the original_dimension_size + # to 1 to match the remaining '1's in the target tensor shape. + if origin_index == len(origin_shape): + original_dimension_size = 1 + origin_dims = [] + else: + original_dimension_size = origin_shape[origin_index] + origin_dims = [origin_len - origin_index - 1] + + # If tgt_index equals to tgt_len, we just need to set the tgt_dimension_size + # to 1 to match the remaining '1's in the original tensor shape. + if tgt_index == len(tgt_shape): + tgt_dimension_size = 1 + tgt_dims = [] + else: + tgt_dimension_size = tgt_shape[tgt_index] + tgt_dims = [tgt_len - tgt_index - 1] + + previous_label = PreviousStatus.RESET + + elif original_dimension_size > tgt_dimension_size: + tgt_index += 1 + + if previous_label == PreviousStatus.TGT: + # if the target dimension size is larger in the previous comparision, which means + # the origin dimension size has already accumulated larger than target dimension size, so + # we need to offload the origin dims and tgt dims into the reshape_mapping_dict. + reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) + original_dimension_size = original_dimension_size // tgt_dimension_size + origin_dims = [origin_len - origin_index - 1] + tgt_dimension_size = tgt_shape[tgt_index] + tgt_dims = [tgt_len - tgt_index - 1, tgt_len - tgt_index] + # reset the previous_label after offloading the origin dims and tgt dims + previous_label = PreviousStatus.RESET + else: + # accumulate the tgt_dimension_size until tgt_dimension_size larger than original_dimension_size + tgt_dimension_size *= tgt_shape[tgt_index] + tgt_dims.append(tgt_len - tgt_index - 1) + previous_label = PreviousStatus.ORIGIN + + else: + origin_index += 1 + + if previous_label == PreviousStatus.ORIGIN: + # if the origin element is larger in the previous comparision, which means + # the target element has already accumulated larger than origin element, so + # we need to offload the origin dims and tgt dims into the reshape_mapping_dict. + reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) + tgt_dimension_size = tgt_dimension_size // original_dimension_size + tgt_dims = [tgt_len - tgt_index - 1] + original_dimension_size = origin_shape[origin_index] + origin_dims = [origin_len - origin_index - 1, origin_len - origin_index] + # reset the previous_label after offloading the origin dims and tgt dims + previous_label = PreviousStatus.RESET + else: + # accumulate the original_dimension_size until original_dimension_size larger than tgt_dimension_size + original_dimension_size *= origin_shape[origin_index] + origin_dims.append(origin_len - origin_index - 1) + previous_label = PreviousStatus.TGT + + return reshape_mapping_dict + + +def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]], + reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool: + """ + This method is used to check whether the reshape operation could implement without converting + the input to fully replicated status. + + Rule: + For a sharded dimension of input tensor, if it is not the minimum element of the input tuple, + the function will return false. + To illustrate this issue, there are two cases to analyse: + 1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal + operation without distributed tensor. + 2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape + consistency process, torch.cat will be implemented on the sharded dim, and everything after the sharded + dim get recovered. + + Examples: + # the second dimension of the input has been sharded. + input_dim_partition_dict = {1: [1]} + origin_shape = torch.Size([8, 4, 2]) + tgt_shape = torch.Size([2, 4, 8]) + reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) + # {(2, 1): (2,), (0,): (1, 0)} + # the sharded dim of input is 1, which is the minimum element of the tuple (2, 1), + # so we do not have to convert the input to fully replicated status. + print(check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict)) + + Output: + True + """ + sharded_dims = list(input_dim_partition_dict.keys()) + for input_dims in reshape_mapping_dict.keys(): + # if input_dims has no element, we could just skip this iteration. + if len(input_dims) == 0: + continue + min_element = min(input_dims) + for dim in input_dims: + if dim in sharded_dims and dim is not min_element: + return False + return True + + +def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]], + reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]: + """ + This method is used to infer the output dim partition dict for a reshape operation, + given the input dim partition dict and reshape mapping dict. + """ + assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \ + 'we only infer output dim partition dict for the reshape operation could keep sharding spec.' + sharded_dims = list(input_dim_partition_dict.keys()) + output_dim_partition_dict = {} + for input_dims, output_dims in reshape_mapping_dict.items(): + for dim in input_dims: + if dim in sharded_dims: + output_dim_partition_dict[min(output_dims)] = input_dim_partition_dict[dim] + # we could break because input dims cannot contain two sharded dims, otherwise + # the keep sharding status check will fail. + break + return output_dim_partition_dict diff --git a/colossalai/auto_parallel/tensor_shard/utils/sharding.py b/colossalai/auto_parallel/tensor_shard/utils/sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ce59e0b5772679be11e960322e3110c500d6aa --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/sharding.py @@ -0,0 +1,120 @@ +import operator +from copy import deepcopy +from functools import reduce +from typing import Dict + +import torch + +from colossalai.tensor.sharding_spec import ShardingSpec + +__all__ = [ + 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', + 'enumerate_all_possible_2d_sharding', 'generate_sharding_size' +] + + +def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec: + """ + Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place. + + Args: + sharding_spec (ShardingSpec): the sharding spec for which partition dim are switched + dim1 (int): the tensor dimension to switch + dim2 (int): the tensor dimension to switch + """ + assert len(sharding_spec.entire_shape) >= 2, \ + 'The entire_shape of the sharding spec must have at least 2 dimensions' + dim_partition_dict = sharding_spec.dim_partition_dict + + # transpose the dim partition + dim1_partition = dim_partition_dict.pop(dim1, None) + dim2_partition = dim_partition_dict.pop(dim2, None) + + if dim1_partition: + dim_partition_dict[dim2] = dim1_partition + if dim2_partition: + dim_partition_dict[dim1] = dim2_partition + + # get the transposed shape + new_shape = list(sharding_spec.entire_shape[:]) + new_shape[dim2], new_shape[dim1] = new_shape[dim1], new_shape[dim2] + new_shape = torch.Size(new_shape) + + # re-init the sharding spec + sharding_spec.__init__(sharding_spec.device_mesh, new_shape, dim_partition_dict) + return sharding_spec + + +def update_partition_dim(sharding_spec: ShardingSpec, + dim_mapping: Dict[int, int], + physical_shape: torch.Size, + inplace: bool = False): + """ + This method is used to update the partition dim dict from the logical one to the physical one. + + Args: + sharding_spec (ShardingSpec): the sharding spec for which partition dims are updated + dim_mapping (Dict[int, int]): the mapping from the logical tensor dimension to the physical tensor dimension + physical_shape (torch.Size): the physical shape for the tensor + """ + + if inplace: + current_sharding_spec = sharding_spec + else: + current_sharding_spec = deepcopy(sharding_spec) + + old_dim_partition_dict = current_sharding_spec.dim_partition_dict + new_dim_partition_dict = {} + + # assign new dim + for old_dim, new_dim in dim_mapping.items(): + mesh_dims = old_dim_partition_dict.pop(old_dim) + new_dim_partition_dict[new_dim] = mesh_dims + + for tensor_dim, mesh_dims in old_dim_partition_dict.items(): + if tensor_dim in new_dim_partition_dict: + raise KeyError(f"There are duplicated entries for the tensor sharding dimension {tensor_dim}") + else: + new_dim_partition_dict[tensor_dim] = mesh_dims + + # update sharding spec + current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh, + entire_shape=physical_shape, + dim_partition_dict=new_dim_partition_dict) + return current_sharding_spec + + +def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size): + dim_partition_list = [] + # enumerate all the 2D sharding cases + for i in range(dim_size): + for j in range(i + 1, dim_size): + dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]} + dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]} + dim_partition_list.append(dim_partition_dict_0) + dim_partition_list.append(dim_partition_dict_1) + for i in range(dim_size): + dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]} + dim_partition_list.append(dim_partition_dict_flatten) + + return dim_partition_list + + +def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size): + dim_partition_list = [] + # enumerate all the 1D sharding cases + for i in range(dim_size): + dim_partition_dict_0 = {i: [mesh_dim_0]} + dim_partition_list.append(dim_partition_dict_0) + + return dim_partition_list + + +def generate_sharding_size(dim_partition_dict, device_mesh): + total_sharding_size = 1 + for mesh_dim_list in dim_partition_dict.values(): + mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list] + sharding_size = reduce(operator.mul, mesh_dim_sharding_size) + total_sharding_size *= sharding_size + + return total_sharding_size diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..d0a467254d7279c37031a755fa62a53fc4e0d9b9 --- /dev/null +++ b/colossalai/autochunk/autochunk_codegen.py @@ -0,0 +1,561 @@ +from typing import Any, Callable, Dict, Iterable, List, Tuple + +import torch + +import colossalai +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE + +AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta() + +if AUTOCHUNK_AVAILABLE: + from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods + +from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + +from .search_chunk import SearchChunk +from .utils import delete_free_var_from_last_use, get_logger, get_node_name, get_node_shape + + +def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str: + """ + Generate chunk slice string, eg. [:, :, chunk_idx_name:chunk_idx_name + chunk_size, :] + + Args: + chunk_dim (int) + chunk_indice_name (str): chunk indice name + shape (List): node shape + + Returns: + new_shape (str): return slice + """ + new_shape = "[" + for idx, _ in enumerate(shape): + if idx == chunk_dim: + new_shape += "%s:%s + chunk_size" % (chunk_indice_name, chunk_indice_name) + else: + new_shape += ":" + new_shape += ", " + new_shape = new_shape[:-2] + "]" + return new_shape + + +def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_ouput_dim: int, chunk_size=2) -> str: + """ + Generate chunk loop start + + eg. chunk_result = torch.empty([100, 100], dtype=input_node.dtype, device=input_node.device) + chunk_size = 32 + for chunk_idx in range(0, 100, 32): + ...... + + Args: + chunk_input (List[Node]): chunk input node + chunk_output (Node): chunk output node + chunk_ouput_dim (int): chunk output node chunk dim + chunk_size (int): chunk size. Defaults to 2. + + Returns: + context (str): generated str + """ + input_node = chunk_input[0] + + context = "" + for i in range(len(chunk_output)): + shape_str = str(list(get_node_shape(chunk_output[i]))) + if get_node_name(chunk_output[i]) in ["split", "unbind"]: + tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name, + input_node.name) + tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta']) + tensor_str = "[" + tensor_str[:-2] + "]" + context += "%s = %s; " % (chunk_output[i].name, tensor_str) + else: + context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str, + input_node.name, input_node.name) + + out_shape = get_node_shape(chunk_output[0]) + chunk_shape = out_shape[chunk_ouput_dim[0]] + context += "chunk_size = %d\nfor chunk_idx in range(0, %d, chunk_size):\n" % (chunk_size, chunk_shape) + return context + + +def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node], + chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str: + """ + Generate chunk loop end + + eg. chunk_result[chunk_idx:chunk_idx + chunk_size] = output_node + output_node = chunk_result; xx = None; xx = None + + Args: + chunk_inputs (List[Node]): chunk input node + chunk_non_compute_inputs (List[Node]): input node without chunk + chunk_outputs (Node): chunk output node + chunk_outputs_dim (int): chunk output node chunk dim + node_list (List) + + Returns: + context (str): generated str + """ + context = "chunk_size = None" + # determine if its the last use for chunk input + for chunk_input in chunk_inputs + chunk_non_compute_inputs: + if all([search_chunk.node_mgr.find_node_idx(user) <= chunk_outputs_idx for user in chunk_input.users.keys()]): + context += "; %s = None" % chunk_input.name + for chunk_output_non_tensor, chunk_output_non_tensor_val in chunk_outputs_non_tensor.items(): + context += "; %s = %s" % (chunk_output_non_tensor.name, chunk_output_non_tensor_val) + context += "\n" + return context + + +def _replace_name(context: str, name_from: str, name_to: str) -> str: + """ + replace node name + """ + patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")] + for p in patterns: + source = p[0] + name_from + p[1] + target = p[0] + name_to + p[1] + if source in context: + context = context.replace(source, target) + break + return context + + +def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict) -> str: + """ + replace reshape size, some may have changed due to chunk + """ + if node_name not in reshape_size_dict: + return context + context = context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1]) + return context + + +def _replace_new_tensor_like_shape( + search_chunk: SearchChunk, + chunk_infos: List[Dict], + region_idx: int, + node_idx: int, + node: Node, + body: List[str], +) -> List[str]: + """ + add chunk slice for new tensor op such as ones like + """ + if get_node_name(node) in ["ones_like", "zeros_like", "empty_like"]: + meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx) + chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] + if get_node_shape(meta_node)[chunk_dim] != 1: + source_node = meta_node.args[0].args[0] + if (source_node not in chunk_infos[region_idx]["node_chunk_dim"] + or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None): + chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node)) + body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice) + return body + + +def _replace_new_tensor_shape( + search_chunk: SearchChunk, + chunk_infos: List[Dict], + region_idx: int, + node_idx: int, + node: Node, + body: List[str], +) -> List[str]: + """ + add chunk slice for new tensor op such as ones + """ + if get_node_name(node) in ["ones", "zeros", "empty"]: + meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx) + chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] + if chunk_dim is None: + return + if get_node_shape(meta_node)[chunk_dim] == 1: + return + origin_shape = str(node.args) + new_shape = list(node.args) + new_shape[chunk_dim] = "min(chunk_size, %d - chunk_idx)" % get_node_shape(meta_node)[chunk_dim] + new_shape = str(new_shape) + new_shape = new_shape.replace("'", "") + body[-1] = _replace_name(body[-1], origin_shape[1:-1], new_shape[1:-1]) + return body + + +def _add_node_slice( + chunk_nodes: List[Node], + region_idx: int, + chunk_nodes_dim: Dict, + node_idx: int, + body: List[str], + node: Node, +) -> List[str]: + """ + add chunk slice for input nodes + """ + for chunk_node_idx, chunk_node in enumerate(chunk_nodes[region_idx]): + # inputs node + if isinstance(chunk_nodes_dim[region_idx][chunk_node_idx], dict): + for idx, dim in chunk_nodes_dim[region_idx][chunk_node_idx].items(): + if idx == node_idx: + chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(chunk_node)) + body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice) + # outputs node + else: + if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]): + chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", + get_node_shape(chunk_node)) + if get_node_name(chunk_node) in ["split", "unbind"]: + split_chunk_slice = "" + for i in range(len(chunk_node.meta['tensor_meta'])): + split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice) + split_chunk_slice = split_chunk_slice[:-2] + body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice) + else: + body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice) + return body + + +def emit_code_with_chunk(body: List[str], + nodes: Iterable[Node], + emit_node_func: Callable, + delete_unused_value_func: Callable, + search_chunk: SearchChunk, + chunk_infos: List, + eval_mem: bool = False): + """ + Emit code with chunk according to chunk_infos. + + It will generate a for loop in chunk regions, and + replace inputs and outputs of regions with chunked variables. + + Args: + body: forward code + nodes: graph.nodes + emit_node_func: function to emit node + delete_unused_value_func: function to remove the unused value + search_chunk: the class to search all chunks + chunk_infos: store all information about all chunks. + """ + node_list = list(nodes) + + # chunk region + chunk_starts = [i["region"][0] for i in chunk_infos] + chunk_ends = [i["region"][1] for i in chunk_infos] + + # chunk inputs + chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim + chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i] + + # chunk outputs + chunk_outputs = [i["outputs"] for i in chunk_infos] + chunk_outputs_non_tensor = [i["outputs_non_tensor"] for i in chunk_infos] + chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] + + node_list = search_chunk.reorder_graph.reorder_node_list(node_list) + node_idx = 0 + region_idx = 0 + within_chunk_region = False + + if eval_mem: + body.append("init_memory = torch.cuda.memory_allocated() / 1024**2\n") + + while node_idx < len(node_list): + node = node_list[node_idx] + + # if is chunk start, generate for loop start + if node_idx in chunk_starts: + within_chunk_region = True + region_idx = chunk_starts.index(node_idx) + body.append( + _gen_loop_start( + chunk_inputs[region_idx], + chunk_outputs[region_idx], + chunk_outputs_dim[region_idx], + chunk_infos[region_idx]["chunk_size"], + )) + + if within_chunk_region: + emit_node_func(node, body) + # replace input var with chunk var + body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node) + # replace output var with chunk var + body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node) + # new tensor like + body = _replace_new_tensor_like_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body) + # new tensor + body = _replace_new_tensor_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body) + # reassign reshape size + body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"]) + body[-1] = " " + body[-1] + delete_unused_value_func(node, body, chunk_inputs_names) + if eval_mem: + body.append( + " if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n" + % (node.name)) + else: + emit_node_func(node, body) + if node_idx not in chunk_inputs: + delete_unused_value_func(node, body, chunk_inputs_names) + if eval_mem: + body.append( + "print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n" + % (node.name)) + + # generate chunk region end + if node_idx in chunk_ends: + body.append( + _gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list, + chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk)) + within_chunk_region = False + + node_idx += 1 + + +if AUTOCHUNK_AVAILABLE: + + class AutoChunkCodeGen(CodeGen): + + def __init__(self, + meta_graph, + max_memory: int = None, + print_mem: bool = False, + print_progress: bool = False, + eval_mem: bool = False) -> None: + super().__init__() + self.eval_mem = eval_mem + # find the chunk regions + self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress) + self.chunk_infos = self.search_chunk.search_region() + if print_progress: + get_logger().info("AutoChunk start codegen") + + def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: + free_vars: List[str] = [] + body: List[str] = [] + globals_: Dict[str, Any] = {} + wrapped_fns: Dict[str, None] = {} + + # Wrap string in list to pass by reference + maybe_return_annotation: List[str] = [""] + + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. + + We call this for names that reference objects external to the + Graph, like functions or types. + + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] is obj + return global_name + globals_[global_name] = obj + return global_name + + # set _custom_builtins here so that we needn't import colossalai in forward + _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) + + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) + + def type_repr(o: Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return "()" + + typename = _type_repr(o) + + if hasattr(o, "__origin__"): + # This is a generic type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + origin_typename = add_global(_type_repr(origin_type), origin_type) + + if hasattr(o, "__args__"): + # Assign global names for each of the inner type variables. + args = [type_repr(arg) for arg in o.__args__] + + if len(args) == 0: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python < 3.9 + return origin_typename + + return f'{origin_typename}[{",".join(args)}]' + else: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python 3.9+ + return origin_typename + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: + + def _get_repr(arg): + # Handle NamedTuples (if it has `_fields`) via add_global. + if isinstance(arg, tuple) and hasattr(arg, "_fields"): + qualified_name = _get_qualified_name(type(arg)) + global_name = add_global(qualified_name, type(arg)) + return f"{global_name}{repr(tuple(arg))}" + return repr(arg) + + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) + if args_s and kwargs_s: + return f"{args_s}, {kwargs_s}" + return args_s or kwargs_s + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + delete_free_var_from_last_use(user_to_last_uses) + + # NOTE: we add a variable to distinguish body and ckpt_func + def delete_unused_values(user: Node, body, to_keep=[]): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == "placeholder": + return + if user.op == "output": + body.append("\n") + return + nodes_to_delete = user_to_last_uses.get(user, []) + nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] + if len(nodes_to_delete): + to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"]) + body.append(f"; {to_delete_str}\n") + else: + body.append("\n") + + # NOTE: we add a variable to distinguish body and ckpt_func + def emit_node(node: Node, body): + maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}") + if node.op == "placeholder": + assert isinstance(node.target, str) + maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}") + free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") + raw_name = node.target.replace("*", "") + if raw_name != repr(node): + body.append(f"{repr(node)} = {raw_name}\n") + return + elif node.op == "call_method": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})") + return + elif node.op == "call_function": + assert callable(node.target) + # pretty print operators + if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods): + assert isinstance(node.args, tuple) + body.append(f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}") + return + + # pretty print inplace operators; required for jit.script to work properly + # not currently supported in normal FX graphs, but generated by torchdynamo + if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods): + body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}") + return + + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str) + and node.args[1].isidentifier() and len(node.args) == 2): + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}") + return + body.append( + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})") + if node.meta.get("is_wrapped", False): + wrapped_fns.setdefault(global_name) + return + elif node.op == "call_module": + assert isinstance(node.target, str) + body.append(f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})") + return + elif node.op == "get_attr": + assert isinstance(node.target, str) + body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}") + return + elif node.op == "output": + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append(self.generate_output(node.args[0])) + return + raise NotImplementedError(f"node: {node.op} {node.target}") + + # Modified for activation checkpointing + ckpt_func = [] + + # if any node has a list of labels for activation_checkpoint, we + # will use nested type of activation checkpoint codegen + emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, + self.eval_mem) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append("pass\n") + + if len(wrapped_fns) > 0: + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + else: + wrap_stmts = "" + + if self._body_transformer: + body = self._body_transformer(body) + + for name, value in self.additional_globals(): + add_global(name, value) + + # as we need colossalai.utils.checkpoint, we need to import colossalai + # in forward function + prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) + prologue = "".join(ckpt_func) + prologue + prologue = prologue + + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) + fn_code = f""" +{wrap_stmts} + +{prologue} +{code}""" + # print(fn_code) + return PythonCode(fn_code, globals_) diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..77bc2ef17bc3bc5faca5903ddd4dfc5a7653275a --- /dev/null +++ b/colossalai/autochunk/estimate_memory.py @@ -0,0 +1,240 @@ +import copy +from typing import Any, Callable, Dict, Iterable, List, Tuple + +import torch +from torch.fx.node import Node + +from colossalai.fx.profiler import activation_size, parameter_size + +from .utils import NodeMgr, get_node_shape, is_non_memory_node + + +class EstimateMemory(object): + """ + Estimate memory with chunk + """ + + def __init__(self) -> None: + pass + + def _get_node_size(self, x: Node) -> float: + """ + return node size in MB + """ + x = x.meta["tensor_meta"] + if not hasattr(x, "numel"): + out = sum([i.numel * torch.tensor([], dtype=i.dtype).element_size() for i in x]) + else: + out = x.numel * torch.tensor([], dtype=x.dtype).element_size() + out = float(out) / 1024**2 + return out + + def _add_active_node(self, n: Node, active_nodes: Dict, chunk_ratio: float) -> None: + """ + add an active node and its shape to active node dict + """ + if get_node_shape(n) is None: + return + if n.op == "placeholder": + return + if n not in active_nodes: + node_size = self._get_node_size(n) * chunk_ratio + active_nodes[n] = node_size + + def _build_delete_node_dict(self, node_mgr: NodeMgr) -> Dict: + """ + build delete node dict, means node should be deleted at what time + """ + delete_node_dict = {} + for idx, node in enumerate(node_mgr.get_node_list()): + # skip non shape node + if get_node_shape(node) is None: + continue + # dont remove free nodes + elif node.op == "placeholder": + delete_node_dict[node] = len(node_mgr.get_node_list()) + # node no user + elif len(node.users) == 0: + delete_node_dict[node] = idx + # log max use + else: + node_user_idx = [node_mgr.find_node_idx(i) for i in node.users.keys()] + delete_node_dict[node] = max(node_user_idx) + return delete_node_dict + + def _remove_deactive_node(self, + user_idx: int, + user: Node, + active_nodes: List, + delete_node_dict: List, + kept_nodes: List = None) -> None: + """ + remove deactivate nodes from active nodes + """ + if kept_nodes is None: + kept_nodes = [] + if user.op in ("output",): + return + + for node in list(active_nodes.keys()): + # dont delete kept nodes + if node in kept_nodes: + continue + # should be deleted + if delete_node_dict[node] <= user_idx: + active_nodes.pop(node) + + def _get_tmp_memory(self, node, not_contiguous_list, delete=False): + mem = 0 + not_contiguous_ops = ["permute"] + + if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]): + for n in node.args: + if n in not_contiguous_list: + # matmul won't change origin tensor, but create a tmp copy + mem += self._get_node_size(n) + elif node.op == "call_module": + for n in node.args: + if n in not_contiguous_list: + # module will just make origin tensor to contiguous + if delete: + not_contiguous_list.remove(n) + elif node.op == "call_method" and any(i in node.name for i in not_contiguous_ops): + if node not in not_contiguous_list: + not_contiguous_list.append(node) + return mem + + def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size): + if node not in chunk_node_dim: + return 1.0 + node_shape = get_node_shape(node) + chunk_dim = chunk_node_dim[node]["chunk_dim"] + if chunk_dim is None: + return 1.0 + else: + return chunk_size / float(node_shape[chunk_dim]) + + def _print_compute_op_mem_log(self, log, nodes, title=None): + if title: + print(title) + for idx, (l, n) in enumerate(zip(log, nodes)): + if n.op in ["placeholder", "get_attr", "output"]: + continue + if any(i in n.name for i in ["getitem", "getattr"]): + continue + print("%s:%.2f \t" % (n.name, l), end="") + if (idx + 1) % 3 == 0: + print("") + print("\n") + + def _add_active_nodes_from_list(self, active_nodes: List, nodes: List) -> List: + """ + add active nodes from nodes + """ + for n in nodes: + self._add_active_node(n, active_nodes, 1) + + def _get_memory_from_active_nodes(self, active_nodes: Dict) -> float: + """ + sum all memory of active nodes + """ + out = [i for i in active_nodes.values()] + out = sum(out) + return out + + def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None, print_mem: bool = False): + """ + Estimate inference memory with chunk + + Args: + node_list (List): _description_ + chunk_infos (Dict): Chunk information. Defaults to None. + print_mem (bool): Wether to print peak memory of every node. Defaults to False. + + Returns: + act_memory_peak_log (List): peak memory of every node + act_memory_after_node_log (List): memory after executing every node + active_node_list_log (List): active nodes of every node. active nodes refer to + nodes generated but not deleted. + """ + act_memory = 0.0 + act_memory_peak_log = [] + act_memory_after_node_log = [] + active_nodes = {} + active_nodes_log = [] + not_contiguous_list = [] + node_mgr = NodeMgr(node_list) + delete_node_dict = self._build_delete_node_dict(node_mgr) + + use_chunk = True if chunk_infos is not None else False + chunk_within = False + chunk_region_idx = None + chunk_ratio = 1 # use it to estimate chunk mem + chunk_inputs_all = [] + + if use_chunk: + chunk_regions = [i["region"] for i in chunk_infos] + chunk_starts = [i[0] for i in chunk_regions] + chunk_ends = [i[1] for i in chunk_regions] + chunk_inputs = [i["inputs"] for i in chunk_infos] + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] + chunk_inputs_all = [j for i in chunk_inputs for j in i] + [j for i in chunk_inputs_non_chunk for j in i] + chunk_outputs = [i["outputs"] for i in chunk_infos] + chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos] + chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos] + + for idx, node in enumerate(node_mgr.get_node_list()): + + # if node in chunk start nodes, change chunk ratio and add chunk_tensor + if use_chunk and idx in chunk_starts: + chunk_within = True + chunk_region_idx = chunk_starts.index(idx) + self._add_active_nodes_from_list(active_nodes, chunk_outputs[chunk_region_idx]) + + # determine chunk ratio for current node + if chunk_within: + chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx], + chunk_sizes[chunk_region_idx]) + + # add current node as active node + self._add_active_node(node, active_nodes, chunk_ratio) + act_memory = self._get_memory_from_active_nodes(active_nodes) + + # if node is placeholder, just add the size of the node + if node.op == "placeholder": + act_memory_peak_log.append(act_memory) + # skip output + elif node.op == "output": + continue + # no change for non compute node + elif is_non_memory_node(node): + act_memory_peak_log.append(act_memory) + # node is a compute op, calculate tmp + else: + # forward memory + # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose + tmp_memory = self._get_tmp_memory(node, not_contiguous_list, delete=True) * chunk_ratio + # record max act memory + act_memory_peak_log.append(act_memory + tmp_memory) + + # remove_deactive_node + self._remove_deactive_node(idx, node, active_nodes, delete_node_dict, kept_nodes=chunk_inputs_all) + + # if node in chunk end nodes, restore chunk settings + if use_chunk and idx in chunk_ends: + self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now + chunk_within = False + chunk_ratio = 1 + chunk_region_idx = None + + act_memory = self._get_memory_from_active_nodes(active_nodes) + act_memory_after_node_log.append(act_memory) + active_nodes_log.append(active_nodes.copy()) + + if print_mem: + print("with chunk" if use_chunk else "without chunk") + self._print_compute_op_mem_log(act_memory_peak_log, node_mgr.get_node_list(), "peak") + + # param_memory = parameter_size(gm) + # all_memory = act_memory + param_memory + return act_memory_peak_log, act_memory_after_node_log, active_nodes_log diff --git a/colossalai/autochunk/reorder_graph.py b/colossalai/autochunk/reorder_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..3b00d47fb9555f0187d8bd290ecb74f78fd94f50 --- /dev/null +++ b/colossalai/autochunk/reorder_graph.py @@ -0,0 +1,111 @@ +from .trace_indice import TraceIndice +from .utils import NodeMgr + + +class ReorderGraph(object): + """ + Reorder node list and indice trace list + """ + + def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None: + self.trace_indice = trace_indice + self.node_mgr = node_mgr + self.all_reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))} + + def _get_reorder_map(self, chunk_info): + reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))} + + chunk_region_start = chunk_info["region"][0] + chunk_region_end = chunk_info["region"][1] + chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] + chunk_prepose_nodes_idx = [self.node_mgr.find_node_idx(i) for i in chunk_prepose_nodes] + # put prepose nodes ahead + for idx, n in enumerate(chunk_prepose_nodes): + n_idx = chunk_prepose_nodes_idx[idx] + reorder_map[n_idx] = chunk_region_start + idx + # put other nodes after prepose nodes + for n in self.node_mgr.get_node_slice_by_idx(chunk_region_start, chunk_region_end + 1): + if n in chunk_prepose_nodes: + continue + n_idx = self.node_mgr.find_node_idx(n) + pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) + reorder_map[n_idx] = n_idx + pos + + return reorder_map + + def _reorder_chunk_info(self, chunk_info, reorder_map): + # update chunk info + chunk_info["region"] = ( + chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]), + chunk_info["region"][1], + ) + new_inputs_dim = [] + for _, input_dim in enumerate(chunk_info["inputs_dim"]): + new_input_dim = {} + for k, v in input_dim.items(): + new_input_dim[reorder_map[k]] = v + new_inputs_dim.append(new_input_dim) + chunk_info["inputs_dim"] = new_inputs_dim + return chunk_info + + def _update_all_reorder_map(self, reorder_map): + for origin_idx, map_idx in self.all_reorder_map.items(): + self.all_reorder_map[origin_idx] = reorder_map[map_idx] + + def _reorder_self_node_list(self, reorder_map): + new_node_list = [None for _ in range(len(self.node_mgr.get_node_list()))] + for old_idx, new_idx in reorder_map.items(): + new_node_list[new_idx] = self.node_mgr.get_node_by_idx(old_idx) + self.node_mgr.update_node_list(new_node_list) + + def _reorder_idx_trace(self, reorder_map): + # reorder list + new_idx_trace_list = [None for _ in range(len(self.trace_indice.indice_trace_list))] + for old_idx, new_idx in reorder_map.items(): + new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx] + self.trace_indice.indice_trace_list = new_idx_trace_list + # update compute + for idx_trace in self.trace_indice.indice_trace_list: + compute = idx_trace["compute"] + for dim_compute in compute: + for idx, i in enumerate(dim_compute): + dim_compute[idx] = reorder_map[i] + # update source + for idx_trace in self.trace_indice.indice_trace_list: + source = idx_trace["source"] + for dim_idx, dim_source in enumerate(source): + new_dim_source = {} + for k, v in dim_source.items(): + new_dim_source[reorder_map[k]] = v + source[dim_idx] = new_dim_source + + def reorder_all(self, chunk_info): + if chunk_info is None: + return chunk_info + if len(chunk_info["args"]["prepose_nodes"]) == 0: + return chunk_info + reorder_map = self._get_reorder_map(chunk_info) + self._update_all_reorder_map(reorder_map) + self._reorder_idx_trace(reorder_map) + self._reorder_self_node_list(reorder_map) + chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) + return chunk_info + + def reorder_node_list(self, node_list): + new_node_list = [None for _ in range(len(node_list))] + for old_idx, new_idx in self.all_reorder_map.items(): + new_node_list[new_idx] = node_list[old_idx] + return new_node_list + + def tmp_reorder(self, node_list, chunk_info): + if len(chunk_info["args"]["prepose_nodes"]) == 0: + return node_list, chunk_info + reorder_map = self._get_reorder_map(chunk_info) + + # new tmp node list + new_node_list = [None for _ in range(len(node_list))] + for old_idx, new_idx in reorder_map.items(): + new_node_list[new_idx] = node_list[old_idx] + + chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) + return new_node_list, chunk_info diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..59645c80e8089d63a53abd8a27c7784f2f90cd8d --- /dev/null +++ b/colossalai/autochunk/search_chunk.py @@ -0,0 +1,293 @@ +import copy +from typing import Dict, List, Tuple + +from torch.fx.node import Node + +from .estimate_memory import EstimateMemory +from .reorder_graph import ReorderGraph +from .select_chunk import SelectChunk +from .trace_flow import TraceFlow +from .trace_indice import TraceIndice +from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder + + +class SearchChunk(object): + """ + This is the core class for AutoChunk. + + It defines the framework of the strategy of AutoChunk. + Chunks will be selected one by one until search stops. + + The chunk search is as follows: + 1. find the peak memory node + 2. find the max chunk region according to the peak memory node + 3. find all possible chunk regions in the max chunk region + 4. find the best chunk region for current status + 5. goto 1 + + Attributes: + gm: graph model + print_mem (bool): print estimated memory + trace_index: trace the flow of every dim of every node to find all free dims + trace_flow: determine the region chunk strategy + reorder_graph: reorder nodes to improve chunk efficiency + estimate_memory: estimate memory with chunk + select_chunk: select the best chunk region + + Args: + gm: graph model + max_memory (int): max memory in MB + print_mem (bool): print estimated memory + """ + + def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None: + self.print_mem = print_mem + self.max_memory = max_memory + self.print_progress = print_progress + self.node_mgr = NodeMgr(list(gm.graph.nodes)) + self.trace_indice = TraceIndice(self.node_mgr) + self.estimate_memory = EstimateMemory() + self._init_trace() + self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr) + self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr) + self.select_chunk = SelectChunk( + self.trace_indice, + self.estimate_memory, + self.reorder_graph, + self.node_mgr, + max_memory=max_memory, + ) + + def _init_trace(self) -> None: + """ + find the max trace range for every node + reduce the computation complexity of trace_indice + """ + # find all max ranges + active_nodes = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())[2] + # set trace range and do the trace + if self.print_progress: + get_logger().info("AutoChunk start tracing indice") + self.trace_indice.set_active_nodes(active_nodes) + self.trace_indice.trace_indice() + + def _find_peak_region(self, mem_peak: List) -> int: + """ + find peak node, along with its neighbor nodes exceeds max mem + """ + max_value = max(mem_peak) + max_idx = mem_peak.index(max_value) + peak_region = [max_idx, max_idx] + if self.max_memory is None: + return peak_region + + # to left + count = 0 + for i in range(max_idx - 1, -1, -1): + if mem_peak[i] > self.max_memory: + peak_region[0] = i + else: + count += 1 + if count >= 3: + break + # to right + count = 0 + for i in range(max_idx + 1, len(mem_peak) - 1): + if mem_peak[i] > self.max_memory: + peak_region[1] = i + count = 0 + else: + count += 1 + if count >= 3: + break + + return peak_region + + def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_regions: List = None) -> Tuple: + """ + Search max chunk region according to peak memory node + + Chunk region starts extending from the peak node, stops where free var num is min + + Args: + active_node (List): active node status for every node + peak_node_idx (int): peak memory node idx + chunk_regions (List): chunk region infos + + Returns: + chunk_region_start (int) + chunk_region_end (int) + """ + # check if peak node already in chunk info + if chunk_regions is not None: + for i in chunk_regions: + if i["region"][0] < peak_region[0] <= i["region"][1] or \ + i["region"][0] < peak_region[1] <= i["region"][1]: + return None + + active_node_num = [len(i) for i in active_node] + window_size = 100 + # search min for start + min_num = 1e4 + for i in range(peak_region[0], max(peak_region[0] - window_size, -1), -1): + if active_node_num[i] < min_num: + min_num = active_node_num[i] + chunk_region_start = i + # search min for end + min_num = 1e4 + for i in range(peak_region[1], min(peak_region[1] + window_size, len(active_node_num))): + if active_node_num[i] < min_num: + min_num = active_node_num[i] + chunk_region_end = i + + # avoid chunk regions overlap + if chunk_regions is not None: + for i in chunk_regions: + region = i["region"] + if chunk_region_start >= region[0] and chunk_region_end <= region[1]: + return None + elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]): + chunk_region_start = region[1] + 1 + elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]): + chunk_region_end = region[0] - 1 + return chunk_region_start, chunk_region_end + + def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List: + """ + Find chunk info for a region. + + We are given the region start and region end, and need to find out all chunk info for it. + We first loop every dim of start node and end node, to see if we can find dim pair, + which is linked in a flow and not computed. + If found, we then search flow in the whole region to find out all chunk infos. + + Args: + input_trace (List): node's input trace in region + output_trace (List): node's output trace in region + start_idx (int): region start node index + end_idx (int): region end node index + + Returns: + chunk_infos: possible regions found + """ + start_traces = input_trace[start_idx] + if len(start_traces) > 1: # TODO need to be removed + return [] + end_trace = output_trace[end_idx] + end_node = self.node_mgr.get_node_by_idx(end_idx) + + chunk_infos = [] + for end_dim, _ in enumerate(end_trace["indice"]): + for start_node, start_trace in start_traces.items(): + for start_dim, _ in enumerate(start_trace["indice"]): + if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, + end_idx): + continue + # flow search + chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim) + if chunk_info is None: + continue + chunk_infos.append(chunk_info) + return chunk_infos + + def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: Node) -> List: + """ + Search every possible region within the max chunk region. + + Args: + max_chunk_region (Tuple) + peak_node (Node): peak memory node + + Returns: + possible_chunk_region (List) + """ + possible_chunk_region = [] + output_trace = copy.deepcopy(self.trace_indice.indice_trace_list) + input_trace = [] # trace of a node's input nodes + for _, n in enumerate(self.node_mgr.get_node_list()): + cur_trace = {} + for arg in n.args: + if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg): + cur_trace[arg] = self.trace_indice._find_trace_from_node(arg) + input_trace.append(cur_trace) + + for start_idx in range(max_chunk_region[0], peak_region[0] + 1): + for end_idx in range(peak_region[1], max_chunk_region[1] + 1): + # skip non compute nodes + if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node( + self.node_mgr.get_node_by_idx(end_idx)): + continue + # select free dim + chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx) + if len(chunk_info) > 0: + possible_chunk_region.extend(chunk_info) + return possible_chunk_region + + def _step_search( + self, + mem_peak: List[float], + active_node: List[List[Node]], + chunk_infos: List[Dict], + ) -> Dict: + """ + Find one chunk region + + The chunk search is as follows: + 1. find the peak memory node + 2. find the max chunk region according to the peak memory node + 3. find all possible chunk regions in the max chunk region + 4. find the best chunk region for current status + + Args: + mem_peak (List): peak memory for every node + active_node (List[List[Node]]): active node for every node + chunk_infos (List[Dict]): all chunk info + + Returns: + best_chunk_region (Dict) + """ + peak_region = self._find_peak_region(mem_peak) + max_chunk_region = self._search_max_chunk_region(active_node, peak_region, chunk_infos) + if max_chunk_region == None: + return None + possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_region) + best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, mem_peak) + best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region) + return best_chunk_region + + def search_region(self) -> Dict: + """ + Search all chunk regions: + 1. Estimate current memory + 2. Find best chunk for current memory + 3. goto 1 + + Returns: + chunk_infos (Dict) + """ + if self.print_progress: + get_logger().info("AutoChunk start searching chunk regions") + + chunk_infos = [] + init_mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list()) + mem_peak = init_mem_peak + + while True: + chunk_info = self._step_search(mem_peak, active_node, chunk_infos) + if chunk_info is None: + break + chunk_infos.append(chunk_info) + + mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem( + self.node_mgr.get_node_list(), chunk_infos) + + if self.print_progress: + get_logger().info("AutoChunk find chunk region %d = (%d, %d)" % + (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1])) + + if self.print_mem: + self.print_mem = False + self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), + chunk_infos, + print_mem=True) + return chunk_infos diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..94a29bfd56911eb9df749f284ae64fbd1b5d7a18 --- /dev/null +++ b/colossalai/autochunk/select_chunk.py @@ -0,0 +1,181 @@ +from .estimate_memory import EstimateMemory +from .reorder_graph import ReorderGraph +from .trace_indice import TraceIndice +from .utils import NodeMgr, is_non_compute_node + + +class SelectChunk(object): + + def __init__( + self, + trace_indice: TraceIndice, + estimate_memory: EstimateMemory, + reorder_graph: ReorderGraph, + node_mgr: NodeMgr, + max_memory=None, + ): + self.trace_indice = trace_indice + self.estimate_memory = estimate_memory + self.reorder_graph = reorder_graph + self.node_mgr = node_mgr + if max_memory is not None: + self.stratge = "fit_memory" + self.max_memory = max_memory # MB + else: + self.stratge = "min_memory" + + def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak): + if self.stratge == "min_memory": + best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos) + elif self.stratge == "fit_memory": + best_region = self._select_fit_memory_chunk_region(possible_chunk_regions, chunk_infos, mem_peak) + else: + raise RuntimeError() + return best_region + + def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak): + # stop chunk if max memory satisfy memory limit + if max(mem_peak) < self.max_memory: + return None + + # remove illegal regions + illegal_regions = [] + for i in possible_chunk_regions: + if not self._is_legal_region(i, chunk_infos): + illegal_regions.append(i) + for i in illegal_regions: + if i in possible_chunk_regions: + possible_chunk_regions.remove(i) + + if len(possible_chunk_regions) == 0: + return None + + # get mem for chunk region + regions_dict = [] + for region in possible_chunk_regions: + cur_region = region.copy() + cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) + cur_chunk_infos = chunk_infos + [cur_region] + cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] + cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1] + cur_chunk_region_max_peak = max(cur_chunk_region_peak) + if cur_chunk_region_max_peak < self.max_memory: + regions_dict.append({ + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), + "reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list, + }) + # no region found + if len(regions_dict) == 0: + raise RuntimeError("Search failed. Try a larger memory threshold.") + + # select the min chunk len + chunk_len = [i["chunk_len"] for i in regions_dict] + best_region_idx = chunk_len.index(min(chunk_len)) + best_region = regions_dict[best_region_idx] + + # get max chunk size + best_region = self._get_fit_chunk_size(best_region, chunk_infos) + return best_region + + def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos): + chunk_size = 1 + reorder_chunk_info = chunk_region_dict["reorder_chunk_info"] + reorder_chunk_info["chunk_size"] = chunk_size + cur_chunk_max_mem = 0 + # search a region + while cur_chunk_max_mem < self.max_memory: + chunk_size *= 2 + reorder_chunk_info["chunk_size"] = chunk_size + cur_chunk_infos = chunk_infos + [reorder_chunk_info] + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"], + cur_chunk_infos)[0] + cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1]) + # search exact size + chunk_info = chunk_region_dict["chunk_info"] + chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict, + chunk_infos) + return chunk_info + + def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos): + if left >= 16: + gap = 4 + else: + gap = 1 + chunk_info = chunk_region_dict["reorder_chunk_info"] + while right >= left + gap: + mid = int((left + right) / 2 + 0.5) + chunk_info["chunk_size"] = mid + cur_chunk_infos = chunk_infos + [chunk_info] + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"], + cur_chunk_infos)[0] + cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1]) + if cur_chunk_max_mem >= self.max_memory: + right = mid - gap + else: + left = mid + gap + return left + + def _get_compute_node_num(self, start, end): + count = 0 + for i in self.node_mgr.get_node_slice_by_idx(start, end + 1): + if not is_non_compute_node(i): + count += 1 + return count + + def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): + # remove illegal regions + illegal_regions = [] + for i in possible_chunk_regions: + if not self._is_legal_region(i, chunk_infos): + illegal_regions.append(i) + for i in illegal_regions: + if i in possible_chunk_regions: + possible_chunk_regions.remove(i) + + if len(possible_chunk_regions) == 0: + return None + + # get max possible chunk region + max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]), + max([i["region"][1] for i in possible_chunk_regions])) + + # get mem for chunk region + regions_dict_list = [] + for region in possible_chunk_regions: + cur_region = region.copy() + cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) + cur_chunk_infos = chunk_infos + [cur_region] + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] + cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1] + cur_chunk_region_max_peak = max(cur_chunk_region_peak) + regions_dict_list.append({ + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), + "reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list, + }) + + # select the min mem + chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list] + best_region_idx = chunk_max_mem.index(min(chunk_max_mem)) + best_region = regions_dict_list[best_region_idx]["chunk_info"] + if best_region is not None: + best_region["chunk_size"] = 1 + return best_region + + def _is_legal_region(self, cur_chunk_info, chunk_infos): + (chunk_region_start, chunk_region_end) = cur_chunk_info["region"] + if cur_chunk_info in chunk_infos: + return False + if chunk_region_end < chunk_region_start: + return False + for i in chunk_infos: + region = i["region"] + if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or + (chunk_region_start < region[0] and chunk_region_end < region[0])): + return False + return True diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..db25267e9b4200414c3cb89afaacd529a1534e06 --- /dev/null +++ b/colossalai/autochunk/trace_flow.py @@ -0,0 +1,485 @@ +from typing import Dict, List, Tuple + +from torch.fx.node import Node + +from .trace_indice import TraceIndice +from .utils import ( + NodeMgr, + find_chunk_all_input_nodes, + find_chunk_compute_input_and_output_nodes, + find_tensor_shape_node, + flat_list, + get_node_name, + get_node_shape, + is_non_compute_node, +) + + +class TraceFlow(object): + + def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None: + self.trace_indice = trace_indice + self.node_mgr = node_mgr + + def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): + """ + Check 2 given index: one index should be source of the other + Args: + start_idx(int): start node chunk dim + start_node(node): start node + end_idx(int): end node chunk dim + end_node(node): end node + + Returns: + bool: True if check pass + """ + # we use start_node_idx instead of real chunk index + start_node_idx = self.node_mgr.find_node_idx(start_node) + end_node_trace = self.trace_indice._find_trace_from_node(end_node) + end_node_trace_source = end_node_trace["source"][end_dim] + sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True) + for node_idx, node_dim in sorted_source: + if node_idx == start_node_idx and start_dim in node_dim: + return True + # it means we meet a node outside the loop, and the node is not input node + if node_idx < start_node_idx: + return False + return False + + def check_index_compute(self, start_idx, end_dim, end_node, end_idx): + """ + Check 2 given index: check they haven't been computed in the source trace. + Args: + start_idx(int): start node chunk dim + start_node(node): start node + end_idx(int): end node chunk dim + end_node(node): end node + + Returns: + bool: True if check pass + """ + end_node_trace = self.trace_indice._find_trace_from_node(end_node) + end_node_compute = end_node_trace["compute"][end_dim] + if any(start_idx <= i <= end_idx for i in end_node_compute): + return False + return True + + def _assgin_single_node_flow( + self, + arg_node: Node, + start_idx: int, + end_idx: int, + cur_node: Node, + cur_node_dim: int, + cur_node_compute: Dict, + cur_node_source: Dict, + cur_node_fix_dim: List, + all_node_info: Dict, + next_node_list: List, + ) -> bool: + """ + Given the current node and one of its arg node, + this function finds out arg node's chunk dim and fix dim + + Args: + arg_node (Node): input node + start_idx (int): chunk region start + end_idx (int): chunk region end + cur_node_dim (int): current node chunk dim + cur_node_compute (Dict): current node compute dict + cur_node_source (Dict): current node source dict + cur_node_fix_dim (List): current node fix dim + all_node_info (Dict): all node chunk info in the chunk region + next_node_list (List) + + Returns: + bool: True if this node can be added to the flow, vice versa. + """ + arg_idx = self.node_mgr.find_node_idx(arg_node) + # arg in chunk range or be inputs + if not (start_idx <= arg_idx < end_idx): + return True + + # get fix dim + arg_fix_dim = [] + if cur_node_dim is not None: + for i in cur_node_fix_dim: + fix_dim_source = cur_node_source[i] + if arg_idx in fix_dim_source: + arg_fix_dim.append(fix_dim_source[arg_idx][0]) + if arg_node in all_node_info: + arg_fix_dim = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)) + + # find arg dim + if cur_node_dim is not None: + # dim is computed + if arg_idx in cur_node_compute[cur_node_dim]: + return False + if arg_idx not in cur_node_source[cur_node_dim]: + arg_dim = None + else: + arg_dim = cur_node_source[cur_node_dim][arg_idx][0] + # chunk dim cannot be in fix dims + if arg_dim in arg_fix_dim: + return False + # chunk dim should be None if shape size is 1 + if get_node_shape(arg_node)[arg_dim] == 1: + arg_dim = None + # chunk shape should equal cur node + elif get_node_shape(arg_node)[arg_dim] != 1: + if cur_node_dim is not None and get_node_shape(cur_node)[cur_node_dim] != 1: + if get_node_shape(arg_node)[arg_dim] != get_node_shape(cur_node)[cur_node_dim]: + return False + else: + arg_dim = None + + # add arg rest dim as fix dim + arg_fix_dim = list(range(len(get_node_shape(arg_node)))) + if arg_dim is not None: + arg_fix_dim.remove(arg_dim) + + # if already in node_info, arg dim must be same + if arg_node in all_node_info: + if all_node_info[arg_node]["chunk_dim"] != arg_dim: + return False + all_node_info[arg_node]["fix_dim"] = arg_fix_dim + # else add it to list + else: + all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim} + + next_node_list.append(arg_node) + return True + + def _get_all_node_info(self, end_dim, start_idx, end_idx): + cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node + all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} + + while len(cur_node_list) > 0: + next_node_list = [] + + for cur_node in cur_node_list: + # get cur node info + cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] + cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] + if cur_node_chunk_dim is not None: + cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node) + cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node) + else: + cur_node_compute = cur_node_source = None + + # get all valid args + arg_list = [] + for arg in cur_node.all_input_nodes: + if type(arg) != type(cur_node): + continue + if is_non_compute_node(arg): + continue + if get_node_shape(arg) is None: + continue + arg_list.append(arg) + flow_flag = self._assgin_single_node_flow( + arg, + start_idx, + end_idx, + cur_node, + cur_node_chunk_dim, + cur_node_compute, + cur_node_source, + cur_node_fix_dim, + all_node_info, + next_node_list, + ) + if flow_flag == False: + return None + + cur_node_list = next_node_list + return all_node_info + + def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple: + """ + Get chunk dim for every input node for their every entry, remove unchunked nodes + + Args: + inputs (List[Node]): input nodes + all_node_info (Dict): describe all node's chunk dim and fix dim + start_idx (int): chunk start idx + end_idx (int): chunk end idx + + Returns: + inputs (List(Node)): new inputs + inputs_dim (List): chunk dim for inputs + """ + inputs_dim = [] + remove_inputs = [] + for input_node in inputs: + input_dict = {} + input_node_idx = self.node_mgr.find_node_idx(input_node) + for user in input_node.users.keys(): + # skip non compute + if is_non_compute_node(user): + continue + # untraced node, mostly non compute + if user not in all_node_info: + continue + user_idx = self.node_mgr.find_node_idx(user) + if start_idx <= user_idx <= end_idx: + chunk_dim = all_node_info[user]["chunk_dim"] + if chunk_dim is not None: + user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim] + if input_node_idx in user_source: + if get_node_shape(input_node)[user_source[input_node_idx][0]] == 1: + input_dict[user_idx] = [None] + else: + input_dict[user_idx] = user_source[input_node_idx] + else: + return None, None + if len(input_dict) == 0: + remove_inputs.append(input_node) + else: + inputs_dim.append(input_dict) + # remove unchunked inputs + for i in remove_inputs: + if i in inputs: + inputs.remove(i) + return inputs, inputs_dim + + def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int, chunk_info) -> List[Node]: + """ + get all useless nodes in chunk region and prepose them + + Args: + all_node_info (Dict): describe all node's chunk dim and fix dim + start_idx (int): chunk start idx + end_idx (int): chunk end idx + + Returns: + List[Node]: all nodes to be preposed + """ + # get all possible prepose nodes + maybe_prepose_nodes = [] + for node, node_info in all_node_info.items(): + if node_info["chunk_dim"] is None: + maybe_prepose_nodes.append(node) + for node in self.node_mgr.get_node_slice_by_idx(start_idx, end_idx): + if node not in all_node_info and node not in chunk_info["outputs"]: + maybe_prepose_nodes.append(node) + maybe_prepose_nodes.sort( + key=lambda x: self.node_mgr.find_node_idx(x), + reverse=True, + ) # from last node to first node + prepose_nodes = [] + # set every node as root, search its args, if all legal, turn root and args as prepose nodes + while len(maybe_prepose_nodes) > 0: + tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]] + tmp_cur_related_prepose_nodes = [] + prepose_flag = True + + # loop cur node's all arg until out of chunk + while len(tmp_cur_prepose_nodes) > 0: + if prepose_flag == False: + break + tmp_next_prepose_nodes = [] + tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes) + for cur_prepose_node in tmp_cur_prepose_nodes: + if prepose_flag == False: + break + for cur_prepose_node_arg in cur_prepose_node.all_input_nodes: + if type(cur_prepose_node_arg) != type(cur_prepose_node): + continue + # out of loop + if not (start_idx <= self.node_mgr.find_node_idx(cur_prepose_node_arg) < end_idx): + continue + # compute op in loop + elif cur_prepose_node_arg in all_node_info: + if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None: + tmp_next_prepose_nodes.append(cur_prepose_node_arg) + else: + prepose_flag = False + break + # non compute op + else: + tmp_next_prepose_nodes.append(cur_prepose_node_arg) + tmp_cur_prepose_nodes = tmp_next_prepose_nodes + + if prepose_flag == False: + maybe_prepose_nodes.remove(maybe_prepose_nodes[0]) + continue + else: + for n in tmp_cur_related_prepose_nodes: + if n not in prepose_nodes: + prepose_nodes.append(n) + if n in maybe_prepose_nodes: + maybe_prepose_nodes.remove(n) + # sort by index + prepose_nodes.sort(key=lambda x: self.node_mgr.find_node_idx(x)) + chunk_info["args"]["prepose_nodes"] = prepose_nodes + + def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): + # we need to log input nodes to avoid deleteing them in the loop + chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1) + # also need to get some prepose node's arg out of non_chunk_inputs + for n in chunk_info["args"]["prepose_nodes"]: + chunk_node_list.remove(n) + non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list) + for i in non_chunk_inputs: + if i not in chunk_info["inputs"]: + chunk_info["inputs_non_chunk"].append(i) + return chunk_info + + def flow_search(self, start_idx, start_dim, end_idx, end_dim): + inputs, outputs = find_chunk_compute_input_and_output_nodes( + self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)) + + # get every node's chunk dim and fix dim + all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) + if all_node_info is None: + return None + + chunk_info = { + "region": (start_idx, end_idx), + "inputs": [], + "inputs_non_chunk": [], + "inputs_dim": [], + "outputs": [self.node_mgr.get_node_by_idx(end_idx)], + "outputs_non_tensor": {}, + "outputs_dim": [end_dim], + "node_chunk_dim": all_node_info, + "args": {}, + } + + # find chunk info for other outputs + if len(find_tensor_shape_node(outputs)) > 1: + chunk_info = self._get_other_output_info(outputs, start_idx, start_dim, end_idx, end_dim, chunk_info) + if chunk_info is None: + return None + + # get input nodes' chunk dim + inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) + if inputs is None: + return None + chunk_info["inputs"] = inputs + chunk_info["inputs_dim"] = inputs_dim + + # move useless nodes ahead of loop + self._get_prepose_nodes(all_node_info, start_idx, end_idx, chunk_info) + + # find non chunk inputs + chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx) + + # reassgin reshape size, some size may have changed due to chunk + chunk_info = self._reassgin_reshape_size(chunk_info) + + return chunk_info + + def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, + chunk_info: Dict): + start_node = self.node_mgr.get_node_by_idx(start_idx) + # loop all outputs + for output in outputs: + output_legal = False + output_idx = self.node_mgr.find_node_idx(output) + # skip the origin output + if output_idx == end_idx: + continue + # skip non tensor + if get_node_shape(output) is None: + # log shape tensor + if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int): + chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out']) + continue + # loop every dim of outputs, try to find a legal one + for output_dim in range(len(get_node_shape(output))): + if not self.check_region_start_end(start_node, start_dim, start_idx, output, output_dim, output_idx): + continue + new_all_node_info = self._get_all_node_info(output_dim, start_idx, output_idx) + if new_all_node_info is None: + continue + # check node info legal + if self._update_chunk_info(chunk_info, new_all_node_info, output, output_dim) == True: + output_legal = True + break + # not legal + if output_legal == False: + return None + return chunk_info + + def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output: Node, output_dim: int) -> bool: + """ + check if there is conflict between new node info and old chunk info. If not, update old chunk info + """ + # check if conflict + overlap_flag = False + for k, v in new_all_node_info.items(): + if k in chunk_info["node_chunk_dim"]: + overlap_flag = True + if chunk_info["node_chunk_dim"][k]["chunk_dim"] != v["chunk_dim"]: + return False + # if no overlap, we just consider them as prepose nodes, instead of new output + if overlap_flag == False: + return True + # update chunk info + for k, v in new_all_node_info.items(): + if k in chunk_info["node_chunk_dim"]: + chunk_info["node_chunk_dim"][k]["fix_dim"] = list( + set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])) + else: + chunk_info["node_chunk_dim"][k] = v + chunk_info["outputs"].append(output) + chunk_info["outputs_dim"].append(output_dim) + return True + + def _reassgin_reshape_size(self, chunk_info): + """ + Some shape args in reshape may have changed due to chunk + reassgin those changed shape + """ + chunk_region = chunk_info["region"] + reshape_size = {} + chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"][0]] + for node in self.node_mgr.get_node_slice_by_idx(chunk_region[0], chunk_region[1] + 1): + if any(i == get_node_name(node) for i in ["reshape", "view"]): + if node in chunk_info["args"]["prepose_nodes"]: + continue + if node.args[0] in chunk_info["inputs_non_chunk"]: + continue + reshape_args = flat_list(node.args[1:]) + if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len( + reshape_args[0].meta['fwd_out']) > 1: + continue + chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] + new_shape = "" + for reshape_arg_dim, reshape_arg in enumerate(reshape_args): + if reshape_arg_dim == chunk_dim: + new_shape += "min(chunk_size, %d - chunk_idx), " % chunk_shape + else: + if isinstance(reshape_arg, int): + new_shape += "%s, " % str(reshape_arg) + else: + new_shape += "%s, " % reshape_arg.name + new_shape = new_shape[:-2] + origin_shape = str(reshape_args)[1:-1] + reshape_size[node.name] = [origin_shape, new_shape] + chunk_info["reshape_size"] = reshape_size + return chunk_info + + def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, + end_idx: int) -> bool: + """ + check if region start and end is legal + """ + # dim cannot be None + if (get_node_shape(end_node) is None or get_node_shape(start_node) is None): + return False + # dim size cannot be 1 + if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1): + return False + # must have users + if len(end_node.users) == 0: + return False + # check index source align + if not self.check_index_source(start_dim, start_node, start_idx, end_dim, end_node): + return False + # check index compute + if not self.check_index_compute(start_idx, end_dim, end_node, end_idx): + return False + return True diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py new file mode 100644 index 0000000000000000000000000000000000000000..c7fce4c8bee1f2886d63c9ca6bb35c332cb293d8 --- /dev/null +++ b/colossalai/autochunk/trace_indice.py @@ -0,0 +1,930 @@ +import copy +from typing import Dict, List, Tuple + +from torch.fx.node import Node + +from .utils import NodeMgr, find_first_tensor_arg, flat_list, get_module_node_name, get_node_name, get_node_shape + + +class TraceIndice(object): + """ + Trace all indice information for every node. + + Indice is a logical concept. Equal dims can been treated as one indice. + eg. dim(x1) = [a, b, c] + dim(x2) = [d, e, f] + and we have x3 = x1 * x2. + then a=d, b=e, c=f, due to the broadcast property, + dim(x1)=dim(x2)=dim(x3)=[a, b, c] + This class will record every node's dims' indice, compute and source. + + Attibutes: + node_list (List) + indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}] + indice_view_list (Dict): not used for now + indice_count (int): record indice number + + Args: + node_list (List) + """ + + def __init__(self, node_mgr: NodeMgr) -> None: + self.node_mgr = node_mgr + self.indice_trace_list = self._init_indice_trace_list() + self.indice_view_list = {} + self.indice_count = -1 + self.active_node_list = [] + + def _init_indice_trace_list(self) -> List: + indice_trace_list = [] + for n in self.node_mgr.get_node_list(): + if get_node_shape(n) != None: + cur_trace = { + "indice": [None for _ in range(len(get_node_shape(n)))], + "compute": [[] for _ in range(len(get_node_shape(n)))], + "source": [{} for _ in range(len(get_node_shape(n)))], + } + else: + cur_trace = {"indice": [], "compute": [], "source": []} + indice_trace_list.append(cur_trace) + return indice_trace_list + + def set_active_nodes(self, active_node_list: List) -> None: + self.active_node_list = active_node_list + + def _add_indice(self) -> int: + """ + Update the count and return it. To record the idx number. + + Returns: + indice_count: int + """ + self.indice_count += 1 + return self.indice_count + + def _del_dim(self, idx: int, dim_idx: int) -> None: + """ + delete a dim for indice, compute and source + """ + self.indice_trace_list[idx]["indice"].pop(dim_idx) + self.indice_trace_list[idx]["compute"].pop(dim_idx) + self.indice_trace_list[idx]["source"].pop(dim_idx) + + def _add_dim(self, node_idx: int, dim_idx: int) -> None: + """ + add a dim for indice, compute and source + """ + # need to remap if dim_idx < 0, e.g. -1 + if dim_idx < 0: + dim_idx = list(range(len(self.indice_trace_list[node_idx]["indice"]) + 1))[dim_idx] + self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice()) + self.indice_trace_list[node_idx]["compute"].insert(dim_idx, []) + self.indice_trace_list[node_idx]["source"].insert(dim_idx, {}) + + def _add_source( + self, + node_from: Node, + node_from_dim: int, + node_to: Node, + node_to_dim: int, + init=False, + ) -> None: + node_from_dim = self._transform_indice(node_from, node_from_dim) + node_from_trace_source = self._find_source_trace_from_node(node_from) + node_to_dim = self._transform_indice(node_to, node_to_dim) + node_to_trace_source = self._find_source_trace_from_node(node_to) + node_from_idx = self.node_mgr.find_node_idx(node_from) + if init: + node_to_trace_source[node_to_dim] = {} + # add dim to cur new source + if node_from_idx not in node_to_trace_source[node_to_dim]: + node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim] + else: + if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]: + node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim) + # update inputs source + for node_idx, node_dim in node_from_trace_source[node_from_dim].items(): + if node_idx not in node_to_trace_source[node_to_dim]: + node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim) + else: + for d in node_dim: + if d not in node_to_trace_source[node_to_dim][node_idx]: + node_to_trace_source[node_to_dim][node_idx].append(d) + + def _transform_indice(self, node: Node, node_dim: int) -> int: + node_idx = self._find_indice_trace_from_node(node) + dims = list(range(len(node_idx))) + return dims[node_dim] + + def _inherit_indice( + self, + node_from: Node, + node_from_dim: int, + node_to: Node, + node_to_dim: int, + init: bool = True, + ) -> None: + """ + node_to's node_to_dim inherit node_from's node_from_dim by indice, compute and source + """ + node_from_dim = self._transform_indice(node_from, node_from_dim) + node_to_dim = self._transform_indice(node_to, node_to_dim) + node_from_trace = self._find_trace_from_node(node_from) + node_to_trace = self._find_trace_from_node(node_to) + if init: + node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim] + node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim]) + else: + for j in node_from_trace["compute"][node_from_dim]: + if j not in node_to_trace["compute"][node_to_dim]: + node_to_trace["compute"][node_to_dim].append(j) + self._add_source(node_from, node_from_dim, node_to, node_to_dim, init) + + def _inherit_all_indice(self, node_from: Node, node_to: Node) -> None: + """ + inherit all dims with init + """ + # find indice just for assert length + node_from_indice = self._find_indice_trace_from_node(node_from) + node_to_indice = self._find_indice_trace_from_node(node_to) + assert len(node_from_indice) == len(node_to_indice) + for i in range(len(node_from_indice)): + self._inherit_indice(node_from, i, node_to, i, init=True) + + def _inherit_more_indice_from_node_with_exclude(self, node_from: Node, node_to: Node, exclude: List = None) -> None: + """ + inherit indice from node without init + """ + if exclude == None: + exclude = [] + else: + exclude = [self._transform_indice(node_to, i) for i in exclude] + node_from_compute = self._find_compute_trace_from_node(node_from) + node_to_compute = self._find_compute_trace_from_node(node_to) + # assert len(node_from_compute) == len(node_to_compute) + for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1): + if self._transform_indice(node_to, i) in exclude: + continue + self._inherit_indice(node_from, i, node_to, i, init=False) + + def _mark_computation(self, node: Node, idx: int, dim: int) -> None: + """ + Mark some dims of node as computed. + + Args: + node (node) + idx (int): node index + dim (list or int): dims to be marked as computed + """ + if isinstance(dim, int): + dim = [dim] + dims = list(range(len(get_node_shape(node)))) + for d in dim: + cur_dim = dims[d] + if idx not in self.indice_trace_list[idx]["compute"][cur_dim]: + self.indice_trace_list[idx]["compute"][cur_dim].append(idx) + + def _find_trace_from_node(self, node: Node) -> Dict: + """ + Find node idx and compute trace by the node. + + Args: + node (node) + Returns: + idx (list): idx of the node + compute (list): computed idx of the node. + """ + node_idx = self.node_mgr.find_node_idx(node) + node_dict = self.indice_trace_list[node_idx] + return node_dict + + def _find_source_trace_from_node(self, node: Node) -> List: + """ + Find node source trace by the node. + + Args: + node (node) + Returns: + idx (list): idx of the node + compute (list): computed idx of the node. + """ + node_idx = self.node_mgr.find_node_idx(node) + node_dict = self.indice_trace_list[node_idx] + return node_dict["source"] + + def _find_indice_trace_from_node(self, node) -> List: + """ + Find node idx trace by the node. + + Args: + node (node) + Returns: + idx (list): idx of the node + """ + node_idx = self.node_mgr.find_node_idx(node) + return self.indice_trace_list[node_idx]["indice"] + + def _find_compute_trace_from_node(self, node: Node) -> List: + """ + Find node compute trace by the node. + + Args: + node (node) + Returns: + compute (list): computed idx of the node. + """ + node_idx = self.node_mgr.find_node_idx(node) + return self.indice_trace_list[node_idx]["compute"] + + def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None: + """ + Assign node's trace as its input node. + + Args: + node (node) + node_idx (int) + """ + if input_node == None: + input_node = find_first_tensor_arg(node) + self._inherit_all_indice(input_node, node) + + def _assign_all_indice(self, node: Node, node_idx: int) -> None: + """ + Add new indice for all node's dims. + + Args: + node (node) + node_idx (int) + """ + shape = node.meta["tensor_meta"].shape + if shape is None: + return + new_trace = [] + for _ in shape: + new_trace.append(self._add_indice()) + self.indice_trace_list[node_idx]["indice"] = new_trace + + def _assign_transpose_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for transpose op. + 1. swap input's dim according to transpose args + 2. inherit input's computation + + Args: + node (node) + node_idx (int) + """ + input_node = node.args[0] + tranpose_dim = node.args[1:] + + self._assign_indice_as_input(node, node_idx, input_node) + self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0]) + self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1]) + + def _assign_permute_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for permute op. + 1. swap input's dim according to permute args + 2. inherit input's computation + + Args: + node (node) + node_idx (int) + """ + permute_dim = flat_list(node.args[1:]) + input_node = node.args[0] + + self._assign_indice_as_input(node, node_idx, input_node) + for idx, d in enumerate(permute_dim): + self._inherit_indice(input_node, d, node, idx) + + def _assign_linear_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for linear op. + 1. copy trace from input node and change last indice according to weight + 2. mark equal for input node last indice, weight first dim and bias dim. + 3. inherit input's computation, mark computation for last dim. + + Args: + node (node) + node_idx (int) + """ + self._assign_indice_as_input(node, node_idx) + + if len(node.args) >= 2: + weight = node.args[1] + self._inherit_indice(weight, 1, node, -1) + else: + self._del_dim(node_idx, -1) + self._add_dim(node_idx, -1) + self._mark_computation(node, node_idx, [-1]) + + def _assign_addmm_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for addmm op. + + Args: + node (node) + node_idx (int) + """ + bias, input_node, weight = node.args + assert len(get_node_shape(bias)) == 1 and len(get_node_shape(weight)) == 2 + self._assign_indice_as_input(node, node_idx, input_node) + self._inherit_indice(weight, 1, node, -1) + self._inherit_more_indice_from_node_with_exclude(bias, node) + + self._mark_computation(node, node_idx, [-1]) + + def _assign_baddbmm_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for baddbmm(batch add and batch matmul) op. + add, matmul_left, matmul_right = args + out = add + (matmul_left x matmul_right) + + Args: + node (node) + node_idx (int) + """ + add, matmul_left, matmul_right = node.args + + assert get_node_shape(add) == get_node_shape(node) + assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right)) + self._assign_indice_as_input(node, node_idx, matmul_left) + # matmul + self._inherit_indice(matmul_right, -1, node, -1) + self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-2, -1]) + self._mark_computation(node, node_idx, [-1]) + # add + self._inherit_more_indice_from_node_with_exclude(add, node) + + def _assign_matmul_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for matmul op. + 1. copy trace from matmul_left and change last indice according to matmul_right. (assert they have same length) + 2. mark equal for input matmul_left -1 indice and matmul_right -2 dim. + 3. inherit matmul_left and matmul_right computation, mark computation for last dim. + + Args: + node (node) + node_idx (int) + """ + matmul_left, matmul_right = node.args + + assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right)) + self._assign_indice_as_input(node, node_idx, matmul_left) + + self._inherit_indice(matmul_right, -1, node, -1) + self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-1, -2]) + self._mark_computation(node, node_idx, [-1]) + + def _assign_conv2d_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for conv2d op. + + Args: + node (node) + node_idx (int) + """ + # get conv module + node_targets = node.target.split(".") + conv_module = node.graph.owning_module + for i in node_targets: + conv_module = getattr(conv_module, i) + assert conv_module.dilation == (1, 1), "dilation for conv2d not implemented" + + # get conv input + assert len(node.args) == 1 + input_node = node.args[0] + assert len(get_node_shape(input_node)) == 4 + + # assgin index + self._assign_indice_as_input(node, node_idx, input_node) + self._del_dim(node_idx, 1) + self._add_dim(node_idx, 1) + self._mark_computation(node, node_idx, [1, 2, 3]) + + def _assign_interpolate_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for interpolate op. + + Args: + node (node) + node_idx (int) + """ + # get conv input + assert node.kwargs['size'] is None + assert len(get_node_shape(node)) == 4 + + # assgin index + self._assign_indice_as_input(node, node_idx) + self._mark_computation(node, node_idx, [-1, -2]) + + def _assign_layernorm_indice(self, node, idx): + """ + Assign indice for layernorm op. + 1. assign indice as input node + 2. inherit computation and mark last 2 dims as computed. + + Args: + node (node) + node_idx (int) + """ + self._assign_indice_as_input(node, idx) + self._mark_computation(node, idx, [-1]) + + def _assign_groupnorm_indice(self, node, idx): + """ + Assign indice for groupnorm op. + + Args: + node (node) + node_idx (int) + """ + assert len(get_node_shape(node)) == 4 + self._assign_indice_as_input(node, idx) + self._mark_computation(node, idx, [-1, -2, -3]) + + def _assign_elementwise_indice(self, node, idx): + """ + Assign indice for element-wise op (eg. relu sigmoid add mul). + 1. assign indice as input node + 2. inherit computation from all input nodes. + + Args: + node (node) + node_idx (int) + """ + self._assign_indice_as_input(node, idx) + nodes_in = [] + for node_in in node.args: + if type(node_in) == type(node): + nodes_in.append(node_in) + self._inherit_more_indice_from_node_with_exclude(node_in, node) + + def _assgin_no_change_indice(self, node, idx): + self._assign_indice_as_input(node, idx) + for node_in in node.args: + if type(node_in) == type(node): + self._inherit_more_indice_from_node_with_exclude(node_in, node) + + def _assign_einsum_indice(self, node, idx): + """ + Assign indice for einsum op. + + Args: + node (node) + node_idx (int) + """ + patterns = node.args[0] + input_nodes = node.args[1:] + + patterns = patterns.replace(" ", "") + left, right = patterns.split("->") + left = left.split(",") + + if "..." in right: + replace_list = "!@#$%^&*" + target_len = len(get_node_shape(node)) + add_len = target_len - len(right) + 3 + replace_str = replace_list[:add_len] + right = right.replace("...", replace_str) + for ll in range(len(left)): + left[ll] = left[ll].replace("...", replace_str) + + all_index = [] + for i in left: + for c in i: + all_index.append(c) + all_index = set(all_index) + + for right_idx, right_indice in enumerate(right): + for left_idx, left_str in enumerate(left): + if right_indice in left_str: + source_idx = left_str.index(right_indice) + self._inherit_indice(input_nodes[left_idx], source_idx, node, right_idx) + + def _assign_softmax_indice(self, node, idx): + """ + Assign indice for softmax op. + 1. assign indice as input node + 2. inherit computation and mark softmax dim as computed. + + Args: + node (node) + node_idx (int) + """ + self._assign_indice_as_input(node, idx) + self._mark_computation(node, idx, [node.kwargs["dim"]]) + + def _assign_split_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for split op. + + Args: + node (node) + node_idx (int) + """ + self._assign_indice_as_input(node, node_idx) + dim_idx = node.kwargs["dim"] + self._del_dim(node_idx, dim_idx) + self._add_dim(node_idx, dim_idx) + + def _assign_unsqueeze_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for unsqueeze op. + 1. assign new indice for unsqueeze dim + + Args: + node (node) + node_idx (int) + """ + self._del_dim(node_idx, -1) + self._assign_indice_as_input(node, node_idx) + dim_idx = node.args[1] + # unsqueeze(-1) = unsqueeze(shape_num + 1) + if dim_idx < 0: + dim_idx = list(range(len(get_node_shape(node))))[dim_idx] + self._add_dim(node_idx, dim_idx) + + def _assign_cat_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for cat op. + + Args: + node (node) + node_idx (int) + """ + nodes_in = flat_list(node.args[0]) + self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0]) + for n in nodes_in[1:]: + self._inherit_more_indice_from_node_with_exclude(n, node) + cat_dim = node.kwargs["dim"] + self._del_dim(node_idx, cat_dim) + self._add_dim(node_idx, cat_dim) + + def _assign_sum_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for sum op. + + Args: + node (node) + node_idx (int) + """ + nodes_in = flat_list(node.args[0]) + self._add_dim(node_idx, 0) + self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0]) + for n in nodes_in[1:]: + self._inherit_more_indice_from_node_with_exclude(n, node) + cat_dim = node.kwargs["dim"] + self._del_dim(node_idx, cat_dim) + + def _assign_flatten_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for flatten op. + + Args: + node (node) + node_idx (int) + """ + nodes_in = node.args[0] + nodes_in_shape = get_node_shape(nodes_in) + flatten_start_dim = node.args[1] + flatten_dim_num = len(nodes_in_shape) - flatten_start_dim - 1 + assert flatten_dim_num > 0 + for _ in range(flatten_dim_num): + self._add_dim(node_idx, 0) + self._assign_indice_as_input(node, node_idx, nodes_in) + for _ in range(flatten_dim_num + 1): + self._del_dim(node_idx, -1) + self._add_dim(node_idx, -1) + + def _assign_expand_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for expand op. + + Args: + node (node) + node_idx (int) + """ + expand_shape = node.args[1:] + node_in_shape = get_node_shape(node.args[0]) + assert len(expand_shape) == len(node_in_shape) + self._assign_indice_as_input(node, node_idx) + for i in range(len(node_in_shape)): + if expand_shape[i] == node_in_shape[i] or expand_shape[i] == -1: + continue + elif expand_shape[i] > node_in_shape[i]: + self._del_dim(node_idx, i) + self._add_dim(node_idx, i) + else: + raise RuntimeError() + + def _assign_unbind_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for unbind op. + + Args: + node (node) + node_idx (int) + """ + unbind_dim = node.args[1] + self._add_dim(node_idx, unbind_dim) + self._assign_indice_as_input(node, node_idx) + self._del_dim(node_idx, unbind_dim) + + def _assign_embedding_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for embedding op. + + Args: + node (node) + node_idx (int) + """ + self._del_dim(node_idx, -1) + self._assign_indice_as_input(node, node_idx) + self._add_dim(node_idx, -1) + + def _assign_getitem_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for getitem. + getitem can act like slice sometimes + + Args: + node (node) + node_idx (int) + """ + node_args = flat_list(node.args[1:]) + + # deal with split + if get_node_name(node.args[0]) == "split": + self._assign_indice_as_input(node, node_idx) + self._del_dim(node_idx, node.args[0].kwargs["dim"]) + self._add_dim(node_idx, node.args[0].kwargs["dim"]) + return + + # skip non tensor + if get_node_shape(node) is None: + return + + # find if slice + flag = False + for node_arg in node_args: + node_arg_str = str(node_arg) + if any(i == node_arg_str for i in ["None", "Ellipsis"]): + flag = True + break + if "slice" in node_arg_str: + flag = True + break + if flag == False: + return + + # node args should be like [Ellipsis, slice(start, step, end), None] + node_shape = get_node_shape(node) + origin_idx_count = 0 + new_idx_count = 0 + new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args]) + for _ in range(new_dim_num): + self._del_dim(node_idx, 0) + delete_dim_num = sum([1 if str(i) == "0" else 0 for i in node_args]) + for _ in range(delete_dim_num): + self._add_dim(node_idx, 0) + self._assign_indice_as_input(node, node_idx) + + for _, node_arg in enumerate(node_args): + node_arg_str = str(node_arg) + # Ellipsis means [..., ] + if "Ellipsis" == node_arg_str: + shape_gap = len(node_shape) - len(node_args) + 1 + origin_idx_count += shape_gap + new_idx_count += shape_gap + # slice(None, None, None) means all indexes + elif "slice" in node_arg_str: + if "slice(None, None, None)" != node_arg_str: + self._del_dim(node_idx, new_idx_count) + self._add_dim(node_idx, new_idx_count) + origin_idx_count += 1 + new_idx_count += 1 + # None means a new dim + elif "None" == node_arg_str: + self._add_dim(node_idx, new_idx_count) + new_idx_count += 1 + elif "0" == node_arg_str: + self._del_dim(node_idx, new_idx_count) + origin_idx_count += 1 + else: + raise NotImplementedError() + + def _assign_view_reshape_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for view and reshape op. + 1. get origin shape and target shape by meta info. + 2. compute the real value of -1 in target shape. + 3. determine changed dim, and assign indice for generated dim. + 4. log changed dim and generated dim for restore + 5. inherit computation. + 6. look into view list to see whether the view is associated with other, + if so assign equal dim according to previous view. + + Args: + node (node) + node_idx (int) + """ + # get data, turn into number + origin_node = node.args[0] + origin_shape = origin_node.meta["tensor_meta"].shape + target_shape = [] + unflated_args = flat_list(node.args) + for i in range(1, len(unflated_args)): + if isinstance(unflated_args[i], int): + target_shape.append(unflated_args[i]) + else: + target_shape.extend(unflated_args[i].meta["fwd_out"]) + + # compute the value of -1 + if -1 in target_shape: + origin_product = 1 + for i in origin_shape: + origin_product *= i + target_product = -1 + for i in target_shape: + target_product *= i + shape_idx = target_shape.index(-1) + target_shape[shape_idx] = origin_product // target_product + + # find same dim + dim_to_same_dim = [] + dim_from_same_dim = [] + for i in range(len(origin_shape)): + if origin_shape[i] == target_shape[i]: + dim_to_same_dim.append(i) + dim_from_same_dim.append(i) + else: + break + for i in range(-1, -len(origin_shape), -1): + if origin_shape[i] == target_shape[i]: + dim_to_same_dim.append(len(target_shape) + i) + dim_from_same_dim.append(len(origin_shape) + i) + else: + break + + dim_from = list(set(range(len(origin_shape))) - set(dim_from_same_dim)) + dim_to = list(set(range(len(target_shape))) - set(dim_to_same_dim)) + assert len(dim_from) == 1 or len(dim_to) == 1 or len(dim_from) == len(dim_to) + + dim_diff = len(dim_from) - len(dim_to) + if dim_diff > 0: + # dim merge + for i in range(dim_diff): + self._add_dim(node_idx, -1) + elif dim_diff < 0: + # dim expand + for i in range(-dim_diff): + self._del_dim(node_idx, -1) + + # get new indice + origin_trace = self._find_indice_trace_from_node(origin_node) + self._assign_indice_as_input(node, node_idx, origin_node) + dim_from.reverse() + for i in dim_from: + self._del_dim(node_idx, i) + for i in dim_to: + self._add_dim(node_idx, i) + dim_from.reverse() + + # inheirt indice from current node + if len(dim_from) != 0 and len(dim_to) != 0: + if dim_diff == 1: + if origin_shape[dim_from[0]] == 1: + self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False) + elif origin_shape[dim_from[1]] == 1: + self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False) + elif dim_diff == -1: + if target_shape[dim_to[0]] == 1: + self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False) + elif target_shape[dim_to[1]] == 1: + self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False) + + # log view, not used now + view_dict = { + "idx_from": [origin_trace[i] for i in dim_from], + "dim_from": dim_from, + "idx_to": [self.indice_trace_list[node_idx]["indice"][i] for i in dim_to], + "dim_to": dim_to, + } + self.indice_view_list[node] = view_dict + + def _clear_trace(self, node_idx: int) -> None: + """ + clear too far trace to speed up computation + """ + trace_barrier = max(node_idx - 100, 0) + active_nodes = self.active_node_list[trace_barrier] + active_nodes = [self.node_mgr.find_node_idx(i) for i in active_nodes.keys()] + + trace = self.indice_trace_list[node_idx] + # clear compute + for dim_compute in trace["compute"]: + for i in range(len(dim_compute) - 1, -1, -1): + if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes): + dim_compute.pop(i) + continue + # clear source + for dim_source in trace["source"]: + for k in list(dim_source.keys()): + if k < trace_barrier and k not in active_nodes: + dim_source.pop(k) + + def trace_indice(self) -> None: + for idx, node in enumerate(self.node_mgr.get_node_list()): + node_name = get_node_name(node) + if node.op == "placeholder": + self._assign_all_indice(node, idx) + elif node.op == "call_method": + if "transpose" == node_name: + self._assign_transpose_indice(node, idx) + elif "permute" == node_name: + self._assign_permute_indice(node, idx) + elif "view" == node_name or "reshape" == node_name: + self._assign_view_reshape_indice(node, idx) + elif "unsqueeze" == node_name: + self._assign_unsqueeze_indice(node, idx) + elif "split" == node_name: + self._assign_split_indice(node, idx) + elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]): + self._assgin_no_change_indice(node, idx) + elif "new_ones" == node_name: + self._assign_all_indice(node, idx) + elif "flatten" == node_name: + self._assign_flatten_indice(node, idx) + elif "expand" == node_name: + self._assign_expand_indice(node, idx) + elif "unbind" == node_name: + self._assign_unbind_indice(node, idx) + elif "softmax" == node_name: + self._assign_softmax_indice(node, idx) + elif any(i == node_name for i in ["size"]): + continue + else: + raise NotImplementedError(node_name, "method not implemented yet!") + elif node.op == "call_function": + if "linear" == node_name: + self._assign_linear_indice(node, idx) + elif "cat" == node_name: + self._assign_cat_indice(node, idx) + elif any(n == node_name for n in ["matmul", "bmm"]): + self._assign_matmul_indice(node, idx) + elif "softmax" == node_name: + self._assign_softmax_indice(node, idx) + elif any(n == node_name for n in [ + "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp", + "sin", "cos" + ]): + self._assign_elementwise_indice(node, idx) + elif "einsum" == node_name: + self._assign_einsum_indice(node, idx) + elif "sum" == node_name: + self._assign_sum_indice(node, idx) + elif "layer_norm" == node_name: + self._assign_layernorm_indice(node, idx) + elif "getitem" == node_name: + self._assign_getitem_indice(node, idx) + elif "addmm" == node_name: + self._assign_addmm_indice(node, idx) + elif "baddbmm" == node_name: + self._assign_baddbmm_indice(node, idx) + elif "interpolate" == node_name: + self._assign_interpolate_indice(node, idx) + elif any(i == node_name for i in ["arange", "ones", "ones_like", "tensor", "empty"]): + self._assign_all_indice(node, idx) + elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]): + continue + else: + raise NotImplementedError(node_name, "function not implemented yet!") + elif node.op == "call_module": + node_name = get_module_node_name(node) + if "layernorm" == node_name: + self._assign_layernorm_indice(node, idx) + elif "groupnorm" == node_name: + self._assign_groupnorm_indice(node, idx) + elif "embedding" == node_name: + self._assign_embedding_indice(node, idx) + elif "linear" == node_name: + self._assign_linear_indice(node, idx) + elif "conv2d" == node_name: + self._assign_conv2d_indice(node, idx) + elif "identity" == node_name: + self._assgin_no_change_indice(node, idx) + elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]): + self._assign_elementwise_indice(node, idx) + else: + raise NotImplementedError(node_name, "module not implemented yet!") + elif node.op == "get_attr": + self._assign_all_indice(node, idx) # get param + elif node.op == "output": + continue + else: + raise NotImplementedError(node.op, "op not implemented yet!") + + # limit trace range + self._clear_trace(idx) diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..064baa047155ac399b14ae0fcdb044db1125d70b --- /dev/null +++ b/colossalai/autochunk/utils.py @@ -0,0 +1,244 @@ +from typing import Any, Callable, Dict, Iterable, List, Tuple, Union + +from torch.fx.node import Node + +from colossalai.logging import get_dist_logger + +NON_COMPUTE_OP = ["placeholder", "get_attr", "output"] +NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "size"] +logger = get_dist_logger() + + +class NodeMgr(object): + + def __init__(self, nodes_list: List[Node]) -> None: + self._node_list = nodes_list + self._node_dict = {} + self._set_node_dict() + + def _set_node_dict(self) -> None: + """ + create a dict {node_name: node_idx} + """ + self._node_dict.clear() + for idx, node in enumerate(self._node_list): + self._node_dict[node.name] = idx + + def find_node_idx(self, node: Node) -> int: + """ + find node's index + """ + return self._node_dict[node.name] + + def find_node_idx_by_name(self, node_name: str) -> int: + """ + find node's index + """ + return self._node_dict[node_name] + + def get_node_by_idx(self, idx: int) -> Node: + """ + get a node by index + """ + return self._node_list[idx] + + def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]: + """ + get a slice of node by index + """ + return self._node_list[start:end] + + def get_node_list(self) -> List: + """ + get full node list + """ + return self._node_list + + def update_node_list(self, node_list: List) -> None: + """ + update node list, reset node dict + """ + self._node_list = node_list + self._set_node_dict() + + +def get_logger() -> Any: + return logger + + +def flat_list(inputs: Any) -> List: + """ + flat a list by recursion + """ + if not (isinstance(inputs, list) or isinstance(inputs, set) or isinstance(inputs, tuple)): + return [inputs] + res = [] + for i in inputs: + if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple): + res.extend(flat_list(i)) + elif isinstance(i, dict): + res.extend(flat_list(list(i.keys()))) + else: + res.append(i) + return res + + +def find_first_tensor_arg(node: Node) -> Node: + """ + Find the first input tensor arg for a node + """ + for arg in node.args: + if type(arg) == type(node): + return arg + raise RuntimeError() + + +def is_non_compute_node(node: Node) -> bool: + if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME): + return True + if "getitem" in node.name: + if get_node_shape(node) is not None: + return False + node_args = flat_list(node.args[1:]) + for node_arg in node_args: + if any(i == str(node_arg) for i in ["None", "Ellipsis"]): + return False + if "slice" in str(node_arg): + return False + return True + return False + + +def get_node_shape(node: Node) -> Any: + """ + return node data shape + """ + if get_node_name(node) in ["split", "unbind"]: + return node.meta["tensor_meta"][0].shape + if hasattr(node.meta["tensor_meta"], "shape"): + return node.meta["tensor_meta"].shape + return None + + +def is_non_memory_node(node: Node) -> bool: + if "getitem" in node.name: + return True + if "output" in node.op: + return True + return is_non_compute_node(node) + + +def is_non_compute_node_except_placeholder(node: Node) -> bool: + if "placeholder" in node.op: + return False + return is_non_compute_node(node) + + +def is_non_compute_node_except_placeholder_output(node: Node) -> bool: + if "output" in node.op: + return False + return is_non_compute_node_except_placeholder(node) + + +def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None: + for key, value in user_to_last_uses.items(): + for n in value: + if n.op == "placeholder": + user_to_last_uses[key].remove(n) + + +def find_chunk_all_input_nodes(nodes: List[Node]) -> List: + """ + Find non-compute input and output node names. + input nodes are nodes used in the list + output nodes are nodes will use nodes in the list + """ + input_nodes = [] + for node in nodes: + for input_node in node._input_nodes.keys(): + if input_node not in nodes and input_node not in input_nodes: + input_nodes.append(input_node) + return input_nodes + + +def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List, List]: + """ + Find non-compute input and output node names. + input nodes are nodes used in the list + output nodes are nodes will use nodes in the list + """ + input_nodes = [] + output_nodes = [] + + # if a node has an input node which is not in the node list + # we treat that input node as the input of the checkpoint function + for node in nodes: + for input_node in node._input_nodes.keys(): + if (input_node not in nodes and input_node not in input_nodes + and not is_non_compute_node_except_placeholder(input_node)): + input_nodes.append(input_node) + + # if a node has a user node which is not in the node list + # we treat that user node as the node receiving the current node output + for node in nodes: + for output_node in node.users.keys(): + if (output_node not in nodes and node not in output_nodes + and not is_non_compute_node_except_placeholder_output(output_node)): + output_nodes.append(node) + + return input_nodes, output_nodes + + +def get_module_node_name(node: Node) -> str: + """ + get module class name + """ + node_targets = node.target.split(".") + module = node.graph.owning_module + for i in node_targets: + module = getattr(module, i) + module_name = str(module.__class__).split(".")[-1][:-2] + module_name = module_name.lower() + return module_name + + +def get_node_name(node: Node) -> str: + """ + get node name + """ + node_name = node.name + if "_" in node_name: + for i in range(len(node_name) - 1, -1, -1): + if node_name[i] == "_": + node_name = node_name[:i] + break + elif node_name[i] in ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]: + continue + else: + break + return node_name + + +def find_tensor_node(node_list: List[Node]) -> List[Node]: + """ + find tensor nodes from a node list + """ + out = [] + for node in node_list: + if get_node_shape(node) is not None: + out.append(node) + return out + + +def find_tensor_shape_node(node_list: List[Node]) -> List[Node]: + """ + find tensor and shape nodes from a node list + """ + out = [] + for node in node_list: + if get_node_shape(node) is not None: + out.append(node) + elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance( + node.meta['fwd_out'][0], int): + out.append(node) + return out diff --git a/colossalai/booster/__init__.py b/colossalai/booster/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..841054a9c67273fb9cb155070b81edf51f8316d2 --- /dev/null +++ b/colossalai/booster/__init__.py @@ -0,0 +1,3 @@ +from .accelerator import Accelerator +from .booster import Booster +from .plugin import Plugin diff --git a/colossalai/booster/accelerator.py b/colossalai/booster/accelerator.py new file mode 100644 index 0000000000000000000000000000000000000000..fc2c4a40068b50cb49db8a3f9c33f22b8f307966 --- /dev/null +++ b/colossalai/booster/accelerator.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn + +__all__ = ['Accelerator'] + +_supported_devices = [ + 'cpu', + 'cuda', + + # To be supported + # 'xpu', + # 'npu', + # 'tpu', +] + + +class Accelerator: + """ + Accelerator is an abstraction for the hardware device that is used to run the model. + + Args: + device (str): The device to be used. Currently only support 'cpu' and 'gpu'. + """ + + def __init__(self, device: str): + self.device = device + + assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}" + + def bind(self): + """ + Set the default device for the current process. + """ + if self.device == 'cpu': + pass + elif self.device == 'cuda': + # TODO(FrankLeeeee): use global environment to check if it is a dist job + # if is_distributed: + # local_rank = EnvTable().get_local_rank() + # torch.cuda.set_device(torch.device(f'cuda:{local_rank}')) + torch.cuda.set_device(torch.device('cuda')) + pass + else: + raise ValueError(f"Device {self.device} is not supported yet") + + def configure_model(self, model: nn.Module) -> nn.Module: + """ + Move the model to the device. + + Args: + model (nn.Module): The model to be moved. + """ + model = model.to(torch.device(self.device)) + return model diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py new file mode 100644 index 0000000000000000000000000000000000000000..c14e602deaf5ce60808a87bbbd238c9419b9d502 --- /dev/null +++ b/colossalai/booster/booster.py @@ -0,0 +1,174 @@ +import warnings +from contextlib import contextmanager +from typing import Callable, Iterator, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader + +from colossalai.checkpoint_io import GeneralCheckpointIO + +from .accelerator import Accelerator +from .mixed_precision import MixedPrecision, mixed_precision_factory +from .plugin import Plugin + +__all__ = ['Booster'] + + +class Booster: + """ + Booster is a high-level API for training neural networks. It provides a unified interface for + training with different precision, accelerator, and plugin. + + Examples: + >>> colossalai.launch(...) + >>> plugin = GeminiPlugin(stage=3, ...) + >>> booster = Booster(precision='fp16', plugin=plugin) + >>> + >>> model = GPT2() + >>> optimizer = Adam(model.parameters()) + >>> dataloader = Dataloader(Dataset) + >>> lr_scheduler = LinearWarmupScheduler() + >>> criterion = GPTLMLoss() + >>> + >>> model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader) + >>> + >>> for epoch in range(max_epochs): + >>> for input_ids, attention_mask in dataloader: + >>> outputs = model(input_ids, attention_mask) + >>> loss = criterion(outputs.logits, input_ids) + >>> booster.backward(loss, optimizer) + >>> optimizer.step() + >>> lr_scheduler.step() + >>> optimizer.zero_grad() + + + Args: + device (str or torch.device): The device to run the training. Default: 'cuda'. + mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None. + If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'. + 'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex. + plugin (Plugin): The plugin to run the training. Default: None. + """ + + def __init__(self, + device: str = 'cuda', + mixed_precision: Union[MixedPrecision, str] = None, + plugin: Optional[Plugin] = None) -> None: + if plugin is not None: + assert isinstance( + plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.' + self.plugin = plugin + + # set accelerator + if self.plugin and self.plugin.control_device(): + self.accelerator = None + warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.') + else: + self.accelerator = Accelerator(device) + + # set precision + if self.plugin and self.plugin.control_precision(): + warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.') + self.mixed_precision = None + elif mixed_precision is None: + self.mixed_precision = None + else: + # validate and set precision + if isinstance(mixed_precision, str): + # the user will take the default arguments for amp training + self.mixed_precision = mixed_precision_factory(mixed_precision) + elif isinstance(mixed_precision, MixedPrecision): + # the user can customize the arguments by passing the precision object + self.mixed_precision = mixed_precision + else: + raise ValueError( + f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.' + ) + + if self.plugin is not None and self.plugin.control_checkpoint_io(): + self.checkpoint_io = self.plugin.get_checkpoint_io() + else: + self.checkpoint_io = GeneralCheckpointIO() + + def boost( + self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None, + dataloader: DataLoader = None, + lr_scheduler: LRScheduler = None, + ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: + """ + Boost the model, optimizer, criterion, lr_scheduler, and dataloader. + + Args: + model (nn.Module): The model to be boosted. + optimizer (Optimizer): The optimizer to be boosted. + criterion (Callable): The criterion to be boosted. + dataloader (DataLoader): The dataloader to be boosted. + lr_scheduler (LRScheduler): The lr_scheduler to be boosted. + """ + # TODO(FrankLeeeee): consider multi-model and multi-optimizer case + # TODO(FrankLeeeee): consider multi-dataloader case + # transform model for mixed precision + if self.plugin: + model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure( + model, optimizer, criterion, dataloader, lr_scheduler) + + if self.plugin and not self.plugin.control_device(): + # transform model for accelerator + model = self.accelerator.configure(model) + + if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()): + # transform model for mixed precision + # when mixed_precision is specified and the plugin is not given or does not control the precision + model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion) + + return model, optimizer, criterion, dataloader, lr_scheduler + + def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None: + # TODO: implement this method with plugin + optimizer.backward(loss) + + def execute_pipeline(self, + data_iter: Iterator, + model: nn.Module, + criterion: Callable[[torch.Tensor], torch.Tensor], + optimizer: Optimizer, + return_loss: bool = True, + return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]: + # TODO: implement this method + # run pipeline forward backward pass + # return loss or outputs if needed + pass + + def no_sync(self, model: nn.Module) -> contextmanager: + assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.' + assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.' + return self.plugin.no_sync(model) + + def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + self.checkpoint_io.load_model(model, checkpoint, strict) + + def save_model(self, + model: nn.Module, + checkpoint: str, + prefix: str = None, + shard: bool = False, + size_per_shard: int = 1024): + self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard) + + def load_optimizer(self, optimizer: Optimizer, checkpoint: str): + self.checkpoint_io.load_optimizer(optimizer, checkpoint) + + def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024): + self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard) + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint) + + def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint) diff --git a/colossalai/booster/mixed_precision/__init__.py b/colossalai/booster/mixed_precision/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3cf0ad28cdbe78b24e1c8fd18fd6bd473e095bd6 --- /dev/null +++ b/colossalai/booster/mixed_precision/__init__.py @@ -0,0 +1,33 @@ +from .bf16 import BF16MixedPrecision +from .fp8 import FP8MixedPrecision +from .fp16_apex import FP16ApexMixedPrecision +from .fp16_torch import FP16TorchMixedPrecision +from .mixed_precision_base import MixedPrecision + +__all__ = [ + 'MixedPrecision', 'mixed_precision_factory', 'FP16_Apex_MixedPrecision', 'FP16_Torch_MixedPrecision', + 'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision' +] + +_mixed_precision_mapping = { + 'fp16': FP16TorchMixedPrecision, + 'fp16_apex': FP16ApexMixedPrecision, + 'bf16': BF16MixedPrecision, + 'fp8': FP8MixedPrecision +} + + +def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision: + """ + Factory method to create mixed precision object + + Args: + mixed_precision_type (str): mixed precision type, including None, 'fp16', 'fp16_apex', 'bf16', and 'fp8'. + """ + + if mixed_precision_type in _mixed_precision_mapping: + return _mixed_precision_mapping[mixed_precision_type]() + else: + raise ValueError( + f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}' + ) diff --git a/colossalai/booster/mixed_precision/bf16.py b/colossalai/booster/mixed_precision/bf16.py new file mode 100644 index 0000000000000000000000000000000000000000..4a840fea69ea280b43c451e033b131993fffd857 --- /dev/null +++ b/colossalai/booster/mixed_precision/bf16.py @@ -0,0 +1,5 @@ +from .mixed_precision_base import MixedPrecision + + +class BF16MixedPrecision(MixedPrecision): + pass diff --git a/colossalai/booster/mixed_precision/fp16_apex.py b/colossalai/booster/mixed_precision/fp16_apex.py new file mode 100644 index 0000000000000000000000000000000000000000..266a750734b14ade3c5f53fc71ef276d09e1ec83 --- /dev/null +++ b/colossalai/booster/mixed_precision/fp16_apex.py @@ -0,0 +1,5 @@ +from .mixed_precision_base import MixedPrecision + + +class FP16ApexMixedPrecision(MixedPrecision): + pass diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..9999aa5e0eb475b8303b76382d9287b3ac876696 --- /dev/null +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -0,0 +1,124 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.optim import Optimizer + +from colossalai.interface import ModelWrapper, OptimizerWrapper + +from .mixed_precision_base import MixedPrecision + +__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule'] + + +class TorchAMPOptimizer(OptimizerWrapper): + """ + Optimizer wrapper for mixed precision training in FP16 using PyTorch AMP. + + Args: + optim (Optimizer): Optimizer to wrap. + init_scale (float): Initial scale factor. Default: 2**16. + growth_factor (float): Factor by which the scale is multiplied during + :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be finite + this iteration. Default: 2.0. + backoff_factor (float): Factor by which the scale is multiplied during + :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be infinite + this iteration. Default: 0.5. + growth_interval (int): Number of iterations between :meth:`torch.cuda.amp.GradScaler.step` + calls that may cause the scale to increase. Default: 2000. + """ + + def __init__(self, + optim: Optimizer, + init_scale: float = 2.**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000) -> None: + super().__init__(optim) + self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval) + + def backward(self, loss: Tensor, *args, **kwargs) -> None: + scaled_loss = self.scale_loss(loss) + scaled_loss.backward(*args, **kwargs) + + def step(self, *args, **kwargs) -> Optional[float]: + out = self.scaler.step(self.optim, *args, **kwargs) + self.scaler.update() + return out + + def scale_loss(self, loss: Tensor) -> Tensor: + return self.scaler.scale(loss) + + def unscale_grad(self) -> None: + self.scaler.unscale_(self.optim) + + def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: + self.unscale_grad() + super().clip_grad_by_value(clip_value, *args, **kwargs) + + def clip_grad_by_norm(self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + error_if_nonfinite: bool = False, + *args, + **kwargs) -> None: + self.unscale_grad() + super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs) + + +class TorchAMPModule(ModelWrapper): + """ + Module wrapper for mixed precision training in FP16 using PyTorch AMP. + + Args: + module (nn.Module): Module to wrap. + """ + + def __init__(self, module: nn.Module): + super().__init__(module) + + def forward(self, *args, **kwargs): + with torch.cuda.amp.autocast(): + return self.module(*args, **kwargs) + + +class FP16TorchMixedPrecision(MixedPrecision): + """ + Precision for mixed precision training in FP16 using PyTorch AMP. + + Args: + init_scale (float): Initial scale factor. Default: 2**16. + growth_factor (float): Factor by which the scale is multiplied during + :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be finite + this iteration. Default: 2.0. + backoff_factor (float): Factor by which the scale is multiplied during + :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be infinite + this iteration. Default: 0.5. + growth_interval (int): Number of iterations between :meth:`torch.cuda.amp.GradScaler.step` + calls that may cause the scale to increase. Default: 2000. + """ + + def __init__(self, + init_scale: float = 2.**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000) -> None: + super().__init__() + self.torch_amp_kwargs = dict(init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval) + + def configure(self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]: + model = TorchAMPModule(model) + optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs) + if criterion is not None: + criterion = TorchAMPModule(criterion) + return model, optimizer, criterion diff --git a/colossalai/booster/mixed_precision/fp8.py b/colossalai/booster/mixed_precision/fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..28847345d91d1a6ab176b6d4675f2728c8e67641 --- /dev/null +++ b/colossalai/booster/mixed_precision/fp8.py @@ -0,0 +1,5 @@ +from .mixed_precision_base import MixedPrecision + + +class FP8MixedPrecision(MixedPrecision): + pass diff --git a/colossalai/booster/mixed_precision/mixed_precision_base.py b/colossalai/booster/mixed_precision/mixed_precision_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2490e9811ccf3ef71a1dcb90d30ddb794fc82d04 --- /dev/null +++ b/colossalai/booster/mixed_precision/mixed_precision_base.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod +from typing import Callable, Tuple + +import torch.nn as nn +from torch.optim import Optimizer + +from colossalai.interface import OptimizerWrapper + + +class MixedPrecision(ABC): + """ + An abstract class for mixed precision training. + """ + + @abstractmethod + def configure(self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]: + # TODO: implement this method + pass diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa45bcb59ad7b9b25a06bea5fa1c0f776e6870ae --- /dev/null +++ b/colossalai/booster/plugin/__init__.py @@ -0,0 +1,6 @@ +from .gemini_plugin import GeminiPlugin +from .low_level_zero_plugin import LowLevelZeroPlugin +from .plugin_base import Plugin +from .torch_ddp_plugin import TorchDDPPlugin + +__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin'] diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..deda00d8a7b3b6849a1cf5b1db0b796a1a1c7d89 --- /dev/null +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -0,0 +1,302 @@ +import random +import warnings +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from torch import Tensor +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO +from colossalai.checkpoint_io.utils import save_state_dict +from colossalai.cluster import DistCoordinator +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.utils import get_current_device +from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper +from colossalai.zero.gemini.memory_tracer import MemStats + +from .plugin_base import Plugin + +__all__ = ['GeminiPlugin'] + + +class GeminiCheckpointIO(GeneralCheckpointIO): + + def __init__(self) -> None: + super().__init__() + self.coordinator = DistCoordinator() + + def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True): + """ + Load model from checkpoint with automatic unwrapping. + """ + # the model should be unwrapped in self.load_model via ModelWrapper.unwrap + return super().load_unsharded_model(model, checkpoint, strict=strict) + + def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + """ + Save model to checkpoint but only on master process. + """ + # the model should be unwrapped in self.load_model via ModelWrapper.unwrap + # as there is communication when get state dict, this must be called on all processes + state_dict = model.state_dict(only_rank_0=True) + if self.coordinator.is_master(): + save_state_dict(state_dict, checkpoint, use_safetensors) + + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + """ + Save optimizer to checkpoint but only on master process. + """ + # TODO(ver217): optimizer state dict is sharded + super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save model to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) + + +class GeminiModel(ModelWrapper): + + def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None: + super().__init__(module) + self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose) + + def unwrap(self): + # as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model + return self.module + + +class GeminiOptimizer(OptimizerWrapper): + + def __init__(self, + module: GeminiDDP, + optimizer: Optimizer, + zero_optim_config: dict, + optim_kwargs: dict, + verbose: bool = False) -> None: + optimizer = zero_optim_wrapper(module, + optimizer, + optim_config=zero_optim_config, + **optim_kwargs, + verbose=verbose) + super().__init__(optimizer) + + def backward(self, loss: Tensor, *args, **kwargs): + self.optim.backward(loss) + + def clip_grad_by_norm(self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2, + error_if_nonfinite: bool = False, + *args, + **kwargs) -> Tensor: + warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm') + + def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: + raise NotImplementedError('Gemini does not support clip_grad_by_value') + + +class GeminiPlugin(Plugin): + """ + Plugin for Gemini. + + Example: + >>> from colossalai.booster import Booster + >>> from colossalai.booster.plugin import GeminiPlugin + >>> + >>> model, train_dataset, optimizer, criterion = ... + >>> plugin = GeminiPlugin() + + >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) + >>> booster = Booster(plugin=plugin) + >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + + Args: + device (torch.device): device to place the model. + placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu". + pin_memory (bool, optional): use pin memory on CPU. Defaults to False. + force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. + strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False. + search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32. + hidden_dim (int, optional): the hidden dimension of DNN. + Users can provide this argument to speed up searching. + If users do not know this argument before training, it is ok. We will use a default value 1024. + min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte. + If the aggregate size of parameters is still samller than the minimum chunk size, + all parameters will be compacted into one small chunk. + memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer. + gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) + which will be used when using hybrid CPU optimizer. + This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto". + Defaults to 0.0. + initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. + min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. + growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. + backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. + growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000. + hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2. + max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32. + max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do + clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm. + norm_type (float, optional): norm_type used for `clip_grad_norm`. + verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False. + """ + + def __init__( + self, + device: Optional[torch.device] = None, + placement_policy: str = "cpu", + pin_memory: bool = False, + force_outputs_fp32: bool = False, + strict_ddp_mode: bool = False, + search_range_mb: int = 32, + hidden_dim: Optional[int] = None, + min_chunk_size_mb: float = 32, + memstats: Optional[MemStats] = None, + gpu_margin_mem_ratio: float = 0.0, + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + verbose: bool = False, + ) -> None: + + assert dist.is_initialized( + ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.gemini_config = dict( + device=(device or get_current_device()), + placement_policy=placement_policy, + pin_memory=pin_memory, + force_outputs_fp32=force_outputs_fp32, + strict_ddp_mode=strict_ddp_mode, + search_range_mb=search_range_mb, + hidden_dim=hidden_dim, + min_chunk_size_mb=min_chunk_size_mb, + memstats=memstats, + ) + self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,) + self.optim_kwargs = dict(initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type) + self.verbose = verbose + + def support_no_sync(self) -> bool: + return False + + def control_precision(self) -> bool: + return True + + def supported_precisions(self) -> List[str]: + return ['fp16'] + + def control_device(self) -> bool: + return True + + def supported_devices(self) -> List[str]: + return ['cuda'] + + def prepare_train_dataloader(self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + **kwargs): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + Note: + 1. Evaluation datasets should not be passed to this function. + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader(dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) + + def configure( + self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None, + dataloader: DataLoader = None, + lr_scheduler: LRScheduler = None, + ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + + if not isinstance(model, ModelWrapper): + # convert model to sync bn + # FIXME(ver217): gemini does not support sync bn + # In torch/nn/modules/_functions.py, line 22, ``mean, invstd = torch.batch_norm_stats(input, eps)`` will get fp32 mean and invstd even though the input is fp16. + # This inconsistency of dtype will cause the error. + # We have two possible solutions: + # 1. keep batch norm always in fp32. This is hard for gemini, as it use chunks. + # 2. patch sync bn or write a new on. This is relatively easy, but we need to test it. + # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) + + # wrap the model with Gemini + model = GeminiModel(model, self.gemini_config, self.verbose) + + if not isinstance(optimizer, OptimizerWrapper): + optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, + self.verbose) + + return model, optimizer, criterion, dataloader, lr_scheduler + + def control_checkpoint_io(self) -> bool: + return True + + def get_checkpoint_io(self) -> CheckpointIO: + return GeminiCheckpointIO() diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..969c430bd317600091e69a68276a5f14006d5420 --- /dev/null +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -0,0 +1,259 @@ +import random +import warnings +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from torch import Tensor +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from colossalai.checkpoint_io import CheckpointIO +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.utils import get_current_device +from colossalai.zero import zero_model_wrapper, zero_optim_wrapper + +from .plugin_base import Plugin +from .torch_ddp_plugin import TorchDDPCheckpointIO + +__all__ = ['LowLevelZeroPlugin'] + + +def _convert_to_fp16(x): + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.half() + return x + + +class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): + + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + """ + Save optimizer to checkpoint but only on master process. + """ + # TODO(ver217): optimizer state dict is sharded + super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) + + +class LowLevelZeroModel(ModelWrapper): + + def __init__(self, module: nn.Module, stage: int, precision: str) -> None: + super().__init__(module) + self.convert_inputs = (precision == 'fp16') + module = zero_model_wrapper(module, zero_stage=stage) + if precision == 'fp16': + module = module.half() + module = module.to(get_current_device()) + self.module = module + + def forward(self, *args, **kwargs): + if self.convert_inputs: + args = tree_map(_convert_to_fp16, args) + kwargs = tree_map(_convert_to_fp16, kwargs) + return super().forward(*args, **kwargs) + + +class LowLevelZeroOptimizer(OptimizerWrapper): + + def __init__(self, + module: nn.Module, + optimizer: Optimizer, + zero_optim_config: dict, + optim_kwargs: dict, + verbose: bool = False) -> None: + optimizer = zero_optim_wrapper(module, + optimizer, + optim_config=zero_optim_config, + **optim_kwargs, + verbose=verbose) + super().__init__(optimizer) + + def backward(self, loss: Tensor, *args, **kwargs): + self.optim.backward(loss) + + def clip_grad_by_norm(self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2, + error_if_nonfinite: bool = False, + *args, + **kwargs) -> Tensor: + warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm') + + def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: + raise NotImplementedError('LowLevelZero does not support clip_grad_by_value') + + +class LowLevelZeroPlugin(Plugin): + """ + Plugin for low level zero. + + Example: + >>> from colossalai.booster import Booster + >>> from colossalai.booster.plugin import LowLevelZeroPlugin + >>> + >>> model, train_dataset, optimizer, criterion = ... + >>> plugin = LowLevelZeroPlugin() + + >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) + >>> booster = Booster(plugin=plugin) + >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + + Args: + strage (int, optional): ZeRO stage. Defaults to 1. + precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'. + initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. + min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. + growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. + backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. + growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000. + hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2. + max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32. + max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do + clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm. + norm_type (float, optional): norm_type used for `clip_grad_norm`. + reduce_bucket_size_in_m (int, optional): grad reduce bucket size in M. Defaults to 12. + communication_dtype (torch.dtype, optional): communication dtype. If not specified, the dtype of param will be used. Defaults to None. + overlap_communication (bool, optional): whether to overlap communication and computation. Defaults to True. + cpu_offload (bool, optional): whether to offload grad, master weight and optimizer state to cpu. Defaults to False. + verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False. + """ + + def __init__( + self, + stage: int = 1, + precision: str = 'fp16', + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + reduce_bucket_size_in_m: int = 12, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + cpu_offload: bool = False, + verbose: bool = False, + ) -> None: + + assert dist.is_initialized( + ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' + assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training' + assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training' + + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + + self.stage = stage + self.precision = precision + self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload) + self.optim_kwargs = dict(initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type) + self.verbose = verbose + + def support_no_sync(self) -> bool: + return False + + def control_precision(self) -> bool: + return True + + def supported_precisions(self) -> List[str]: + return ['fp16', 'fp32'] + + def control_device(self) -> bool: + return True + + def supported_devices(self) -> List[str]: + return ['cuda'] + + def prepare_train_dataloader(self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + **kwargs): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + Note: + 1. Evaluation datasets should not be passed to this function. + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader(dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) + + def configure( + self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None, + dataloader: DataLoader = None, + lr_scheduler: LRScheduler = None, + ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + + if not isinstance(model, ModelWrapper): + model = LowLevelZeroModel(model, self.stage, self.precision) + + if not isinstance(optimizer, OptimizerWrapper): + optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, + self.verbose) + + return model, optimizer, criterion, dataloader, lr_scheduler + + def control_checkpoint_io(self) -> bool: + return True + + def get_checkpoint_io(self) -> CheckpointIO: + return LowLevelZeroCheckpointIO() diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py new file mode 100644 index 0000000000000000000000000000000000000000..7a222022c1b264585b03ea4c402e580a909c4af7 --- /dev/null +++ b/colossalai/booster/plugin/plugin_base.py @@ -0,0 +1,61 @@ +from abc import ABC, abstractmethod +from typing import Callable, List, Tuple, Union + +import torch.nn as nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader + +from colossalai.checkpoint_io import CheckpointIO +from colossalai.interface import OptimizerWrapper + +__all__ = ['Plugin'] + + +class Plugin(ABC): + + @abstractmethod + def supported_devices(self) -> List[str]: + pass + + @abstractmethod + def supported_precisions(self) -> List[str]: + pass + + @abstractmethod + def control_precision(self) -> bool: + pass + + @abstractmethod + def control_device(self) -> bool: + pass + + @abstractmethod + def support_no_sync(self) -> bool: + pass + + @abstractmethod + def configure( + self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None, + dataloader: DataLoader = None, + lr_scheduler: LRScheduler = None, + ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + # implement this method + pass + + @abstractmethod + def control_checkpoint_io(self) -> bool: + """ + Whether the plugin controls the checkpoint io + """ + pass + + @abstractmethod + def get_checkpoint_io(self) -> CheckpointIO: + """ + Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True. + """ + pass diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..c5e310c7e7695893266784e9b90018b0d373f6f2 --- /dev/null +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -0,0 +1,204 @@ +import random +from typing import Callable, List, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO +from colossalai.cluster import DistCoordinator +from colossalai.interface import ModelWrapper, OptimizerWrapper + +from .plugin_base import Plugin + +__all__ = ['TorchDDPPlugin'] + + +class TorchDDPCheckpointIO(GeneralCheckpointIO): + + def __init__(self) -> None: + super().__init__() + self.coordinator = DistCoordinator() + + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + """ + Load model from checkpoint with automatic unwrapping. + """ + # the model should be unwrapped in self.load_model via ModelWrapper.unwrap + return super().load_unsharded_model(model, checkpoint, strict=strict) + + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + """ + Save model to checkpoint but only on master process. + """ + # the model should be unwrapped in self.load_model via ModelWrapper.unwrap + if self.coordinator.is_master(): + super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) + + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + """ + Save optimizer to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save model to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) + + +class TorchDDPModel(ModelWrapper): + + def __init__(self, module: nn.Module, *args, **kwargs) -> None: + super().__init__(module) + self.module = DDP(module, *args, **kwargs) + + def unwrap(self): + return self.module.module + + +class TorchDDPPlugin(Plugin): + """ + Plugin for PyTorch DDP. + + Example: + >>> from colossalai.booster import Booster + >>> from colossalai.booster.plugin import TorchDDPPlugin + >>> + >>> model, train_dataset, optimizer, criterion = ... + >>> plugin = TorchDDPPlugin() + + >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) + >>> booster = Booster(plugin=plugin) + >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + + Args: + broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True. + bucket_cap_mb (int, optional): The bucket size in MB. Defaults to 25. + find_unused_parameters (bool, optional): Whether to find unused parameters. Defaults to False. + check_reduction (bool, optional): Whether to check reduction. Defaults to False. + gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Defaults to False. + static_graph (bool, optional): Whether to use static graph. Defaults to False. + """ + + def __init__(self, + broadcast_buffers: bool = True, + bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False) -> None: + + assert dist.is_initialized( + ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers, + bucket_cap_mb=bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph) + + def support_no_sync(self) -> bool: + return True + + def control_precision(self) -> bool: + return False + + def supported_precisions(self) -> List[str]: + return ['fp16', 'fp16_apex', 'bf16', 'fp8'] + + def control_device(self) -> bool: + return True + + def supported_devices(self) -> List[str]: + return ['cuda'] + + def prepare_train_dataloader(self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + **kwargs): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + Note: + 1. Evaluation datasets should not be passed to this function. + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader(dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) + + def configure( + self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None, + dataloader: DataLoader = None, + lr_scheduler: LRScheduler = None, + ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + # cast model to cuda + model = model.cuda() + + # convert model to sync bn + model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) + + # wrap the model with PyTorch DDP + model = TorchDDPModel(model, **self.ddp_kwargs) + + if not isinstance(optimizer, OptimizerWrapper): + optimizer = OptimizerWrapper(optimizer) + + return model, optimizer, criterion, dataloader, lr_scheduler + + def control_checkpoint_io(self) -> bool: + return True + + def get_checkpoint_io(self) -> CheckpointIO: + return TorchDDPCheckpointIO() diff --git a/colossalai/builder/__init__.py b/colossalai/builder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf09e1e7a31a15e979c5358c5da27683a0ccb2f9 --- /dev/null +++ b/colossalai/builder/__init__.py @@ -0,0 +1,3 @@ +from .builder import build_from_config, build_from_registry, build_gradient_handler + +__all__ = ['build_gradient_handler', 'build_from_config', 'build_from_registry'] diff --git a/colossalai/builder/builder.py b/colossalai/builder/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..4a907601327c9c938243bfee121165937c02537c --- /dev/null +++ b/colossalai/builder/builder.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import inspect + +from colossalai.registry import * + + +def build_from_config(module, config: dict): + """Returns an object of :class:`module` constructed from `config`. + + Args: + module: A python or user-defined class + config: A python dict containing information used in the construction of the return object + + Returns: An ``object`` of interest + + Raises: + AssertionError: Raises an AssertionError if `module` is not a class + + """ + assert inspect.isclass(module), 'module must be a class' + return module(**config) + + +def build_from_registry(config, registry: Registry): + r"""Returns an object constructed from `config`, the type of the object + is specified by `registry`. + + Note: + the `config` is used to construct the return object such as `LAYERS`, `OPTIMIZERS` + and other support types in `registry`. The `config` should contain + all required parameters of corresponding object. The details of support + types in `registry` and the `mod_type` in `config` could be found in + `registry `_. + + Args: + config (dict or :class:`colossalai.context.colossalai.context.Config`): information + used in the construction of the return object. + registry (:class:`Registry`): A registry specifying the type of the return object + + Returns: + A Python object specified by `registry`. + + Raises: + Exception: Raises an Exception if an error occurred when building from registry. + """ + config_ = config.copy() # keep the original config untouched + assert isinstance(registry, Registry), f'Expected type Registry but got {type(registry)}' + + mod_type = config_.pop('type') + assert registry.has(mod_type), f'{mod_type} is not found in registry {registry.name}' + try: + obj = registry.get_module(mod_type)(**config_) + except Exception as e: + print(f'An error occurred when building {mod_type} from registry {registry.name}', flush=True) + raise e + + return obj + + +def build_gradient_handler(config, model, optimizer): + """Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`, + `model` and `optimizer`. + + Args: + config (dict or :class:`colossalai.context.Config`): A python dict or + a :class:`colossalai.context.Config` object containing information + used in the construction of the ``GRADIENT_HANDLER``. + model (:class:`nn.Module`): A model containing parameters for the gradient handler + optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing parameters for the gradient handler + + Returns: + An object of :class:`colossalai.engine.BaseGradientHandler` + """ + config_ = config.copy() + config_['model'] = model + config_['optimizer'] = optimizer + return build_from_registry(config_, GRADIENT_HANDLER) diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c25048e25754eb6fc55db1b1de4ac7b21d05bda3 --- /dev/null +++ b/colossalai/checkpoint_io/__init__.py @@ -0,0 +1,5 @@ +from .checkpoint_io_base import CheckpointIO +from .general_checkpoint_io import GeneralCheckpointIO +from .index_file import CheckpointIndexFile + +__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO'] diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py new file mode 100644 index 0000000000000000000000000000000000000000..cb853559c48c1a6ceb41231e312a2480fd68bd97 --- /dev/null +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -0,0 +1,330 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Union +from typing import Optional + +import torch +import torch.nn as nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler + +from colossalai.interface import ModelWrapper + +from .utils import has_index_file + +__all__ = ['CheckpointIO'] + + +class CheckpointIO(ABC): + """ + CheckpointIO is the base class for all checkpoint IO classes. It defines the interface for checkpoint IO. + + + Examples: + >>> from colossalai.checkpoint_io import GeneralCheckpointIO + >>> checkpoint_io = CheckpointIO() + >>> + >>> # load model from checkpoint + >>> model = checkpoint_io.load_model(model, 'model.pt') + >>> + >>> # save model to checkpoint, any distributed tensor is gathered by default + >>> checkpoint_io.save_model(model, 'model.pt') + >>> + >>> # if the model contains distributed tensor, and you don't want to gather it + >>> # each rank will save its own shard of the distributed tensor + >>> checkpoint_io.save_model(model, 'model.pt', gather_dtensor=False) + >>> + >>> # save model to sharded checkpoints + >>> checkpoint_io.save_model(model, './checkpoints/', shard=True) + >>> + >>> # save model to sharded and assume we don't want to gather distributed tensors + >>> checkpoint_io.save_model(model, './checkpoints/', shard=True, gather_dtensor=False) + >>> + >>> # Note: + >>> # 1. we don't support loading from distributed tensors, conversion from distributed tensors + >>> # checkpoints to full tensor checkpoint should be done offline via our CLI + >>> # 2. you don't have to specify whether the model is sharded or not when loading the model + >>> # as it will be automatically detected + >>> + >>> # load model from sharded checkpoints + >>> model = checkpoint_io.load_model(model, './checkpoints/') + >>> + >>> # load model from unsharded checkpoints + >>> model = checkpoint_io.load_model(model, './checkpoints/') + >>> + >>> # load optimizer from checkpoint + >>> optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt') + >>> + >>> # save optimizer to checkpoint + >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt') + """ + + # ====================================== + # Public methods + # ====================================== + def load_model(self, + model: Union[nn.Module, ModelWrapper], + checkpoint: str, + strict: bool = True) -> Union[nn.Module, ModelWrapper]: + """ + Load model from checkpoint. + + Args: + model (nn.Module): model to be loaded. + checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the + mainstream model zoos such as Hugging Face and TIMM. The checkpoint path can be: + 1. a file path, e.g. 'model.pt' + 2. a path to a json file which defines the index to the sharded checkpoint + 3. a path to a folder containing a unique .index.json file for sharded checkpoint + Distributed tensors cannot be loaded directly unless gathered offline via our CLI. + strict (bool): whether to strictly enforce that the param name in + the checkpoint match the keys returned by this module's. + """ + # since we only support loaded sharded and unsharded weight format + # containing no distributed tensors, dtensor -> full tensor conversion + # should be done offline via our CLI + # the existence of index file means it is a sharded checkpoint + ckpt_path = Path(checkpoint) + index_file_exists, index_file_path = has_index_file(checkpoint) + + # return the origin model instead of the unwrapped model + origin_model = model + + if isinstance(model, ModelWrapper): + model = model.unwrap() + + if index_file_exists: + self.load_sharded_model(model, index_file_path, strict) + else: + self.load_unsharded_model(model, checkpoint, strict) + + return origin_model + + def save_model(self, + model: Union[nn.Module, ModelWrapper], + checkpoint: str, + shard: bool = False, + gather_dtensor: bool = True, + variant: str = None, + size_per_shard: int = 1024, + use_safetensors: bool = False): + """ + Save model to checkpoint. + + Examples: + >>> from colossalai.checkpoint_io import GeneralCheckpointIO + >>> checkpoint_io = CheckpointIO() + >>> + >>> # save model to a single file + >>> save_model(model, 'model.pt') + >>> + >>> # save model to a sharded checkpoint + >>> save_model(model, './checkpoints/', shard=True) + + Args: + model (nn.Module): model to be saved. + checkpoint (str): checkpoint path. The checkpoint path can be : + 1. a file path, e.g. 'model.pt' + 2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True. + shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into + multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure + that the checkpoint path is a directory path instead of a file path. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True. + variant (str): If specified, weights are saved in the format pytorch_model..bin. Default: None. + size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True. + use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved + """ + + if isinstance(model, ModelWrapper): + model = model.unwrap() + + if shard: + self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors) + else: + self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) + + def load_optimizer(self, optimizer: Optimizer, checkpoint: str): + """ + Load optimizer from checkpoint. + + Args: + optimizer (Optimizer): optimizer to be loaded. + checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the + """ + index_file_exists, index_file_path = has_index_file(checkpoint) + + if Path(checkpoint).is_dir() and not index_file_exists: + # if the checkpoint is a directory and there is no index file, raise error + raise ValueError(f'Cannot find index file in {checkpoint}') + + if index_file_exists: + # the existence of index file means it is a sharded checkpoint + self.load_sharded_optimizer(optimizer, index_file_path) + else: + self.load_unsharded_optimizer(optimizer, checkpoint) + + def save_optimizer(self, + optimizer: Optimizer, + checkpoint: str, + shard: bool = False, + gather_dtensor=True, + prefix: str = None, + size_per_shard: int = 1024): + """ + Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors. + + Args: + optimizer (Optimizer): optimizer to be saved. + checkpoint (str): checkpoint path. The checkpoint path can be : + 1. a file path, e.g. 'model.pt' + 2. a path to a json file which defines the index to the sharded checkpoint for the optimizer + 3. a path to a folder containing a unique .index.json file for sharded checkpoint + shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into + multiple files. The optimizer shards will be specified by a `optimizer.index.json` file. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True. + prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None. + size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True. + """ + if shard: + self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard) + else: + self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) + + # ======================================================== + # Abstract methods for model loading/saving implementation + # ======================================================== + @abstractmethod + def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool): + """ + Load model from sharded checkpoint. + + Args: + model (nn.Module): model to be loaded. + index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. + strict (bool): whether to strictly enforce that the param name in + the checkpoint match the keys returned by this module's. + """ + pass + + @abstractmethod + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): + """ + Load model from unsharded checkpoint. + + Args: + model (nn.Module): model to be loaded. + checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. + strict (bool): whether to strictly enforce that the param name in + the checkpoint match the keys returned by this module's. + """ + pass + + @abstractmethod + def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str], + size_per_shard: int, use_safetensors: bool): + """ + Save model to sharded checkpoint. + + Args: + model (nn.Module): model to be saved. + checkpoint (str): checkpoint path. It should be a directory path. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. + prefix (str): prefix for the model checkpoint. + size_per_shard (int): size per shard in MB. + use_safetensors (bool): whether to use safe tensors. + """ + pass + + @abstractmethod + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + """ + Save model to unsharded checkpoint. + + Args: + model (nn.Module): model to be saved. + checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. + use_safetensors (bool): whether to use safe tensors. + """ + pass + + # ======================================================== + # Abstract methods for optimizer loading/saving implementation + # ======================================================== + + @abstractmethod + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int): + """ + Load optimizer from sharded checkpoint. + + Args: + optimizer (Optimizer): optimizer to be loaded. + index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. + prefix (str): prefix for the optimizer checkpoint. + size_per_shard (int): size per shard in MB. + """ + pass + + @abstractmethod + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + """ + Load optimizer from unsharded checkpoint. + + Args: + optimizer (Optimizer): optimizer to be loaded. + checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. + """ + pass + + @abstractmethod + def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, + size_per_shard: int): + """ + Save optimizer to sharded checkpoint. + + Args: + optimizer (Optimizer): optimizer to be saved. + checkpoint (Path): checkpoint path. It should be a directory path. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. + prefix (str): prefix for the optimizer checkpoint. + size_per_shard (int): size per shard in MB. + """ + pass + + @abstractmethod + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool): + """ + Save optimizer to unsharded checkpoint. + + Args: + optimizer (Optimizer): optimizer to be saved. + checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. + """ + pass + + # ============================================ + # methods for loading and saving lr scheduler + # as this is quite standard, there is no need + # to make them abstract + # ============================================ + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save lr scheduler to checkpoint. + + Args: + lr_scheduler (LRScheduler): lr scheduler to be saved. + checkpoint: checkpoint path. The checkpoint path can only be a file path. + """ + torch.save(lr_scheduler.state_dict(), checkpoint) + + def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Load lr scheduler from checkpoint. + + Args: + lr_scheduler (LRScheduler): lr scheduler to be loaded. + checkpoint (str): the path for a single checkpoint file. + """ + state_dict = torch.load(checkpoint) + lr_scheduler.load_state_dict(state_dict) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py new file mode 100644 index 0000000000000000000000000000000000000000..bf584f45d045228c5d6d1e02b470c7696f5194db --- /dev/null +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -0,0 +1,138 @@ +from pathlib import Path + +import torch.nn as nn +from torch.optim import Optimizer +import logging +import os +import json +import gc +from typing import Optional + +from .checkpoint_io_base import CheckpointIO +from .index_file import CheckpointIndexFile +from .utils import ( + has_index_file, + load_state_dict, + save_state_dict, + is_safetensors_available, + shard_checkpoint, + load_shard_state_dict, + load_state_dict_into_model, + add_variant + ) +from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME + + +__all__ = ['GeneralCheckpointIO'] + + +class GeneralCheckpointIO(CheckpointIO): + """ + Checkpoint IO + """ + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): + checkpoint = load_state_dict(checkpoint) + model.load_state_dict(checkpoint, strict=strict) + + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + state_dict = model.state_dict() + + # TODO(FrankLeeeee): add support for gather_dtensor + if gather_dtensor: + pass + + # save the checkpoint + save_state_dict(state_dict, checkpoint, use_safetensors) + + def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int): + raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + checkpoint = load_state_dict(checkpoint) + optimizer.load_state_dict(checkpoint) + + def save_sharded_optimizer( + self, + optimizer: Optimizer, + checkpoint: Path, + gather_dtensor: bool, + prefix: str, + size_per_shard: int, + ): + raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + + def save_unsharded_optimizer( + self, + optimizer: Optimizer, + checkpoint: Path, + gather_dtensor: bool, + ): + # TODO(FrankLeeeee): handle distributed tensors + save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False) + + + def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False, + variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): + """ + implement this method as it can be supported by Huggingface model, + save shard model, save model to multiple files + """ + if os.path.isfile(checkpoint_path): + logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + return + + Path(checkpoint_path).mkdir(parents=True, exist_ok=True) + + # shard checkpoint + state_dict = model.state_dict() + weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME + weights_name = add_variant(weights_name, variant) + shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) + + # Save the model + for shard_file, shard in shards.items(): + checkpoint_file_path = os.path.join(checkpoint_path, shard_file) + save_state_dict(shard, checkpoint_file_path, use_safetensors) + + # save index file + save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME + + save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant)) + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logging.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + + def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False): + """ + load shard model, load model from multiple files + """ + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # read checkpoint index file + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames() + missing_keys = ckpt_index_file.get_all_param_names() + + for shard_file in checkpoint_files: + state_dict = load_shard_state_dict(Path(shard_file), use_safetensors) + load_state_dict_into_model(model, state_dict, missing_keys, strict) + del state_dict + gc.collect() + + if strict and len(missing_keys) > 0: + error_msgs = 'Missing key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in missing_keys)) + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(error_msgs))) + diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py new file mode 100644 index 0000000000000000000000000000000000000000..89224787a91b9824f2261aece909f6a1cb094a17 --- /dev/null +++ b/colossalai/checkpoint_io/index_file.py @@ -0,0 +1,156 @@ +import json +from pathlib import Path +from typing import Any, List, Union + +from .utils import is_dtensor_checkpoint + +__all__ = ['CheckpointIndexFile'] + + +class CheckpointIndexFile: + """ + This class is a data structure to keep the content in the index.json file for sharded checkpoint. + + Example: + >>> index = CheckpointIndexFile.from_file('model.index.json') + >>> index.append_metadata('model_type', 'bert') + >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'model_0001-of-0002.bin') + >>> index.export('new_index.json') + """ + + def __init__(self) -> None: + self.root_path = None + self.metadata: dict = dict() + self.weight_map: dict = dict() + + @staticmethod + def from_file(index_path: Union[str, Path]): + """ + Create a CheckpointIndexFile object from a json file. + + Args: + index_path (str): path to the json file. + + Returns: + CheckpointIndexFile: CheckpointIndexFile object. + """ + index = CheckpointIndexFile() + index.load(index_path) + return index + + def load(self, json_path: str): + """ + Load the index file from a json file. + + Args: + json_path (str): path to the json file. + """ + # load the json file + with open(json_path, 'r') as f: + index = json.load(f) + + # assign attributes if exists + if "metadata" in index: + self.metadata = index["metadata"] + if "weight_map" in index: + self.weight_map = index["weight_map"] + + # assign the root directory for the index file + self.root_path = Path(json_path).absolute().parent + + def export(self, json_path: str): + """ + Export the index file to a json file. + + Args: + json_path (str): path to the json file. + """ + # create the index file + index = dict() + index["metadata"] = self.metadata + index["weight_map"] = self.weight_map + + # export the index file + with open(json_path, 'w') as f: + json.dump(index, f, indent=4) + + def append_weight_map(self, param_name: str, shard_file: str): + """ + Append a weight map entry to the index file. + + Args: + param_name (str): name of the parameter. + shard_file (str): name of the shard file. + """ + self.weight_map[param_name] = shard_file + + def append_meta_data(self, name: str, val: Any): + """ + Append a metadata entry to the index file. + + Args: + name (str): name of the metadata. + val (Any): value of the metadata. + """ + self.metadata[name] = val + + def contains_dtensor(self): + """ + Check if the index file contains any distributed tensor. The distributed tensors will be stored in + `dtensor/module.linear.weight.*.bin` or `dtensor/module.linear.weight.*.safetensors` in the weight map. + + Returns: + bool: True if the index file contains any distributed tensor, False otherwise. + """ + for value in self.weight_map.values(): + if value.endswith(".*.bin") or value.endswith(".*.safetensors"): + return True + return False + + def get_checkpoint_fileanames(self) -> List[str]: + """ + Get the set of checkpoint filenames in the weight map. + + Returns: + list: checkpoint shard filenames. + """ + # read the checkpoint file list from the json file and get a list of unique file names + checkpoint_files = sorted(list(set(self.weight_map.values()))) + + # get the absolute paths for all checkpoint files + checkpoint_files = [str(self.root_path.joinpath(f)) for f in checkpoint_files] + + dtensor_list = [] + checkpoint_list = [] + + for ckpt_file in checkpoint_files: + if is_dtensor_checkpoint(ckpt_file): + dtensor_list.append(ckpt_file) + else: + checkpoint_list.append(ckpt_file) + + return checkpoint_list, dtensor_list + + def assert_no_dtensor_checkpoint(self): + for val in self.weight_map.values(): + if is_dtensor_checkpoint(val): + raise ValueError(f"Checkpoint file {val} contains distributed tensor") + + def get_checkpoint_file(self, param_name: str) -> str: + """ + Get the checkpoint file name for a parameter. + + Args: + param_name (str): name of the parameter. + + Returns: + str: checkpoint file name. + """ + ckpt_path = self.weight_map[param_name] + return ckpt_path + + def get_all_param_names(self): + """ + Get all the weight keys. + """ + return list(self.weight_map.keys()) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..37d22d08df40eaaa209ad32b097ce735a14378dc --- /dev/null +++ b/colossalai/checkpoint_io/utils.py @@ -0,0 +1,419 @@ +# coding=utf-8 +from pathlib import Path +import torch +import torch.nn as nn +from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple +from colossalai.tensor.d_tensor.d_tensor import DTensor +import re + +SAFE_WEIGHTS_NAME = "model.safetensors" +WEIGHTS_NAME = "pytorch_model.bin" +SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" +WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" + +# ====================================== +# General helper functions +# ====================================== + +def calculate_tensor_size(tensor: torch.Tensor) -> float: + """ + Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size. + If so, a new shard should be created. + + Args: + tenosr (torch.Tensor): the tensor to calculate size for. + + Returns: + float: size of the tensor in MB. + """ + return tensor.numel() * tensor.element_size() / 1024 / 1024 + +def is_safetensors_available() -> bool: + """ + Check whether safetensors is available. + + Returns: + bool: whether safetensors is available. + """ + try: + import safetensors + return True + except ImportError: + return False + + +def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool: + """ + Check whether the checkpoint file is a dtensor checkpoint. + + Args: + checkpoint_file_path (str): path to the checkpoint file. + + Returns: + bool: whether the checkpoint file is a dtensor checkpoint. + """ + if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'): + return True + else: + return False + + +def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: + """ + Check whether the checkpoint file is a safetensor checkpoint. + + Args: + checkpoint_file_path (str): path to the checkpoint file. + + Returns: + bool: whether the checkpoint file is a safetensor checkpoint. + """ + if checkpoint_file_path.endswith('.safetensors'): + return True + else: + return False + + +# ====================================== +# Helper functions for saving shard file +# ====================================== +def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME): + + """ + Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. + """ + sharded_state_dicts = [] + current_block = {} + current_block_size = 0 + total_size = 0 + + for key, weight in state_dict.items(): + if type(weight) != DTensor: + weight_size = calculate_tensor_size(weight) + + # If this weight is going to tip up over the maximal size, we split. + if current_block_size + weight_size > max_shard_size: + sharded_state_dicts.append(current_block) + current_block = {} + current_block_size = 0 + + current_block[key] = weight + current_block_size += weight_size + total_size += weight_size + + # Add the last block + sharded_state_dicts.append(current_block) + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {weights_name: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + + for idx, shard in enumerate(sharded_state_dicts): + shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") + shard_file = shard_file.replace( + ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors" + ) + shards[shard_file] = shard + for key in shard.keys(): + weight_map[key] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + return shards, index + +def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False): + """ + load shard state dict into model + """ + if use_safetensors and not checkpoint_file.suffix == ".safetensors": + raise Exception("load the model using `safetensors`, but no file endwith .safetensors") + if use_safetensors: + from safetensors.torch import safe_open + from safetensors.torch import load_file as safe_load_file + with safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + if metadata["format"] != "pt": + raise NotImplementedError( + f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet." + ) + return safe_load_file(checkpoint_file) + else: + return torch.load(checkpoint_file) + +def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False): + r"""Copies parameters and buffers from :attr:`state_dict` into + this module and its descendants. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + """ + if not isinstance(state_dict, Mapping): + raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) + + unexpected_keys: List[str] = [] + sub_missing_keys: List[str] = [] + error_msgs: List[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = OrderedDict(state_dict) + if metadata is not None: + state_dict._metadata = metadata + + def load(module: nn.Module, state_dict, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + # Parameters of module and children will start with prefix. We can exit early if there are none in this + # state_dict + if len([key for key in state_dict if key.startswith(prefix)]) > 0: + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, state_dict, prefix + name + ".") + + load(model, state_dict, "") + del load + + # deal with missing key + if len(missing_keys) > 0: + deleted_keys = [] + for key in missing_keys: + if key not in sub_missing_keys: + deleted_keys.append(key) + for key in deleted_keys: + missing_keys.remove(key) + + if strict: + if len(unexpected_keys) > 0: + error_msgs = 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in unexpected_keys)) + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs))) + +# ====================================== +# Helper functions for saving state dict +# ====================================== + + +def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: + """ + Save state dict to checkpoint. + + Args: + state_dict (dict): state dict. + checkpoint_file_path (str): path to the checkpoint file. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + """ + if use_safetensors: + assert is_safetensors_available(), "safetensors is not available." + assert checkpoint_file_path.endswith('.safetensors'), \ + "safetensors only supports .safetensors suffix for checkpoint file." + from safetensors.torch import save_file as safe_save_file + safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) + else: + torch.save(state_dict, checkpoint_file_path) + + +def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: + """ + Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains + only one tensor. + + Args: + tensor (Tensor): tensor to be saved. + index_file (CheckpointIndexFile): path to the checkpoint file. + size_per_shard (int): size per shard in MB. + """ + root_path = index_file.root_path + output_root_path = root_path.joinpath('dtensor') + + # create directory + output_root_path.mkdir(exist_ok=True) + + # save tensor to this directory + # TODO(YuliangLiu): get index of the tensor shard + # e.g. index = + index = 0 + + # save tensor to file + ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors) + ckpt_file_path = output_root_path.joinpath(ckpt_file_name) + + # dtensor ckpt file always contains only one tensor + state_dict = {name: tensor} + save_state_dict(state_dict, str(ckpt_file_path), use_safetensors) + + # update the weight map + # * means all shards + ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors) + index_file.append_weight_map(name, ckpt_file_name_in_weight_map) + + +def get_checkpoint_file_suffix(use_safetensors: bool) -> str: + """ + Get checkpoint file suffix. + + Args: + use_safetensors (bool): whether to use safetensors to save the checkpoint. + + Returns: + str: checkpoint file suffix. + """ + if use_safetensors: + return '.safetensors' + else: + return '.bin' + + +def generate_checkpoint_shard_file_name(index: int, + total_number: int, + use_safetensors: bool, + prefix: str = None) -> str: + """ + Generate checkpoint shard file name. + + Args: + index (int): index of the shard. + total_number (int): total number of shards. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + prefix (str): prefix of the shard file name. Default: None. + + Returns: + str: checkpoint shard file name. + """ + suffix = get_checkpoint_file_suffix(use_safetensors) + + if prefix is None: + return f"{index:05d}-of-{total_number:05d}.{suffix}" + else: + return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}" + + +def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str: + """ + Generate dtensor file name. + + Args: + param_name (str): name of the distributed parameter. + index (int): index of the shard. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + + Returns: + str: dtensor file name. + """ + suffix = get_checkpoint_file_suffix(use_safetensors) + return f'{param_name}.{index}.{suffix}' + + +def save_state_dict_as_shard( + state_dict: dict, + checkpoint_path: str, + index: int, + total_number: int, + use_safetensors: bool, + prefix: str = None, +) -> None: + """ + Save state dict as shard. + + Args: + state_dict (dict): state dict. + checkpoint_path (str): path to the checkpoint file. + index (int): index of the shard. + total_number (int): total number of shards. + prefix (str): prefix of the shard file name. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + """ + # generate the shard name + shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix) + shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute() + + # save the shard + save_state_dict(state_dict, str(shard_file_path), use_safetensors) + + +# ======================================== +# Helper functions for loading state dict +# ======================================== + + +def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: + """ + Check whether the checkpoint has an index file. + + Args: + checkpoint_path (str): path to the checkpoint. + + Returns: + Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path) + """ + checkpoint_path = Path(checkpoint_path) + if checkpoint_path.is_file(): + # check if it is .index.json + reg = re.compile("(.*?).index((\..*)?).json") + if reg.fullmatch(checkpoint_path.name) is not None: + return True, checkpoint_path + else: + return False, None + elif checkpoint_path.is_dir(): + # check if there is only one a file ending with .index.json in this directory + index_files = list(checkpoint_path.glob('*.index.*json')) + + # if we found a .index.json file, make sure there is only one + if len(index_files) > 0: + assert len( + index_files + ) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}' + + if len(index_files) == 1: + return True, index_files[0] + else: + return False, None + + +def load_state_dict(checkpoint_file_path: Path): + """ + Load state dict from checkpoint. + + Args: + checkpoint_file_path (Path): path to the checkpoint file. + + Returns: + dict: state dict. + """ + + assert not is_dtensor_checkpoint(checkpoint_file_path), \ + f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.' + + if is_safetensor_checkpoint(checkpoint_file_path): + assert is_safetensors_available(), \ + f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.' + # load with safetensors + from safetensors import safe_open + state_dict = {} + with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + return state_dict + + else: + # load with torch + return torch.load(checkpoint_file_path) + + + +def add_variant(weights_name: str, variant: Optional[str] = None) -> str: + if variant is not None and len(variant) > 0: + splits = weights_name.split(".") + splits = splits[:-1] + [variant] + splits[-1:] + weights_name = ".".join(splits) + + return weights_name diff --git a/colossalai/cli/__init__.py b/colossalai/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..658e35e4c72e77f7b3161bf86b0c9600a80562e5 --- /dev/null +++ b/colossalai/cli/__init__.py @@ -0,0 +1,3 @@ +from .cli import cli + +__all__ = ['cli'] diff --git a/colossalai/cli/benchmark/__init__.py b/colossalai/cli/benchmark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..618ff8c61dd41ec83d567e7cd9103f2aa9921846 --- /dev/null +++ b/colossalai/cli/benchmark/__init__.py @@ -0,0 +1,28 @@ +import click + +from colossalai.context import Config + +from .benchmark import run_benchmark +from .utils import * + +__all__ = ['benchmark'] + + +@click.command() +@click.option("-g", "--gpus", type=int, default=None, help="Total number of devices to use.") +@click.option("-b", "--batch_size", type=int, default=8, help="Batch size of the input tensor.") +@click.option("-s", "--seq_len", type=int, default=512, help="Sequence length of the input tensor.") +@click.option("-d", "--dimension", type=int, default=1024, help="Hidden dimension of the input tensor.") +@click.option("-w", "--warmup_steps", type=int, default=10, help="The number of warmup steps.") +@click.option("-p", "--profile_steps", type=int, default=50, help="The number of profiling steps.") +@click.option("-l", "--layers", type=int, default=2) +@click.option("-m", + "--model", + type=click.Choice(['mlp'], case_sensitive=False), + default='mlp', + help="Select the model to benchmark, currently only supports MLP") +def benchmark(gpus: int, batch_size: int, seq_len: int, dimension: int, warmup_steps: int, profile_steps: int, + layers: int, model: str): + args_dict = locals() + args = Config(args_dict) + run_benchmark(args) diff --git a/colossalai/cli/benchmark/benchmark.py b/colossalai/cli/benchmark/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..97a9f45722dd6a4c1e316d1e91f27439797ae17a --- /dev/null +++ b/colossalai/cli/benchmark/benchmark.py @@ -0,0 +1,105 @@ +from functools import partial +from typing import Dict, List + +import click +import torch.multiprocessing as mp + +import colossalai +from colossalai.cli.benchmark.utils import find_all_configs, get_batch_data, profile_model +from colossalai.context import Config +from colossalai.context.random import reset_seeds +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.testing import free_port +from colossalai.utils import MultiTimer + +from .models import MLP + + +def run_benchmark(args: Config) -> None: + """ + Run benchmarking with torch.multiprocessing. + """ + + # sanity checks + if args.gpus is None: + click.echo("Error: --num_gpus is not given") + exit() + if args.gpus <= 1: + click.echo("Warning: tensor parallel will be activated with at least 2 devices.") + + click.echo("=== Benchmarking Parameters ===") + for k, v in args.items(): + click.echo(f'{k}: {v}') + click.echo('') + + config_list = find_all_configs(args.gpus) + + avail_ports = [free_port() for _ in range(len(config_list))] + run_func = partial(run_dist_profiling, + world_size=args.gpus, + port_list=avail_ports, + config_list=config_list, + hyperparams=args) + mp.spawn(run_func, nprocs=args.gpus) + + +def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict], + hyperparams: Config) -> None: + """ + A function executed for profiling, this function should be spawn by torch.multiprocessing. + + Args: + rank (int): rank of the process + world_size (int): the number of processes + port_list (List[int]): a list of free ports for initializing distributed networks + config_list (List[Dict]): a list of configuration + hyperparams (Config): the hyperparameters given by the user + + """ + + # disable logging for clean output + disable_existing_loggers() + logger = get_dist_logger() + logger.set_level('WARNING') + + for config, port in zip(config_list, port_list): + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + timer = MultiTimer() + + # 1D parallel should be skipped if in_features or out_features is not able to be divided exactly by 1D parallel size. + if config.parallel.tensor.mode == '1d' and hyperparams.dimension % config.parallel.tensor.size != 0: + click.echo( + "1D parallel will be skipped because in_features or out_features is not able to be divided exactly by 1D parallel size." + ) + continue + + if hyperparams.model == 'mlp': + model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers) + else: + if gpc.get_global_rank() == 0: + click.echo("Error: Invalid argument for --model") + exit() + + data_func = partial(get_batch_data, + dim=hyperparams.dimension, + batch_size=hyperparams.batch_size, + seq_length=hyperparams.seq_len, + mode=config.parallel.tensor.mode) + + fwd_time, bwd_time, max_allocated, max_cached = profile_model(model=model, + warmup_steps=hyperparams.warmup_steps, + profile_steps=hyperparams.profile_steps, + data_func=data_func, + timer=timer) + + gpc.destroy() + reset_seeds() + + if gpc.get_global_rank() == 0: + config_str = ', '.join([f'{k}: {v}' for k, v in config.parallel.tensor.items()]) + click.echo(f"=== {config_str} ===") + click.echo(f"Average forward time: {fwd_time}") + click.echo(f"Average backward time: {bwd_time}") + click.echo(f"Max allocated GPU memory: {max_allocated}") + click.echo(f"Max cached GPU memory: {max_cached}\n") diff --git a/colossalai/cli/benchmark/models.py b/colossalai/cli/benchmark/models.py new file mode 100644 index 0000000000000000000000000000000000000000..f8fd1c41a059806891713340f2ea4931ec9726f2 --- /dev/null +++ b/colossalai/cli/benchmark/models.py @@ -0,0 +1,18 @@ +import torch + +import colossalai.nn as col_nn + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int, layers: int): + super().__init__() + self.layers = torch.nn.ModuleList() + + for _ in range(layers): + self.layers.append(col_nn.Linear(dim, dim)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x diff --git a/colossalai/cli/benchmark/utils.py b/colossalai/cli/benchmark/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..825b795f21f680bcb2e2ea5eee4b328c2e1777db --- /dev/null +++ b/colossalai/cli/benchmark/utils.py @@ -0,0 +1,158 @@ +import math +import time +import torch + +from colossalai.utils import MultiTimer +from colossalai.context import ParallelMode, Config +from typing import List, Dict, Tuple, Callable + + +def get_time_stamp() -> int: + """ + Return the time stamp for profiling. + + Returns: + time_stamp (int): the time given by time.time() + """ + + torch.cuda.synchronize() + time_stamp = time.time() + return time_stamp + + +def get_memory_states() -> Tuple[float]: + """ + Return the memory statistics. + + Returns: + max_allocated (float): the allocated CUDA memory + max_cached (float): the cached CUDA memory + """ + + max_allocated = torch.cuda.max_memory_allocated() / (1024**3) + max_cached = torch.cuda.max_memory_reserved() / (1024**3) + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + return max_allocated, max_cached + + +def find_all_configs(device_cnt: int) -> List[Dict]: + """ + Find all possible configurations for tensor parallelism + + Args: + device_cnt (int): the number of devices + + Returns: + config_list (List[Dict]): a list of configurations + """ + + def _is_square(num): + # 2D parallel should be implemented with at least 2 devices. + if num <= 1: + return False + return math.floor(math.sqrt(num))**2 == num + + def _is_cube(num): + # 3D parallel should be implemented with at least 2 devices. + if num <= 1: + return False + return math.floor(num**(1. / 3.))**3 == num + + config_list = [] + + # add non-parallel config + config = dict(parallel=dict(tensor=dict(size=device_cnt, mode=None))) + config_list.append(config) + + # add 1D config + config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='1d'))) + config_list.append(config) + + # add 2D config only if device_cnt is a square + if _is_square(device_cnt): + config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2d'))) + config_list.append(config) + + # check for 2.5D + # iterate over depth + for depth in range(1, device_cnt): + if device_cnt % depth == 0 and _is_square(device_cnt // depth): + config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2.5d', depth=depth))) + config_list.append(config) + + # check for 3D if device_cnt is a cube + if _is_cube(device_cnt): + config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='3d'))) + config_list.append(config) + + config_list = [Config(cfg) for cfg in config_list] + return config_list + + +def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable, + timer: MultiTimer) -> Tuple[float]: + """ + Profile the forward and backward of a model + + Args: + model (torch.nn.Module): a PyTorch model + warmup_steps (int): the number of steps for warmup + profile_steps (int): the number of steps for profiling + data_func (Callable): a function to generate random data + timer (colossalai.utils.Multitimer): a timer instance for time recording + + Returns: + fwd_time (float): the average forward time taken by forward pass in second + bwd_time (float): the average backward time taken by forward pass in second + max_allocated (float): the maximum GPU memory allocated in GB + max_cached (float): the maximum GPU memory cached in GB + """ + + def _run_step(data): + timer.start('forward') + out = model(data) + timer.stop('forward', keep_in_history=True) + timer.start('backward') + out.mean().backward() + timer.stop('backward', keep_in_history=True) + + data_list = [data_func() for _ in range(warmup_steps)] + for data in data_list: + _run_step(data) + timer.reset('forward') + timer.reset('backward') + + for _ in range(profile_steps): + data = data_func() + _run_step(data) + + max_allocated, max_cached = get_memory_states() + fwd_time = timer.get_timer('forward').get_history_mean() + bwd_time = timer.get_timer('backward').get_history_mean() + return fwd_time, bwd_time, max_allocated, max_cached + + +def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor: + """ + Return a random data of shape (batch_size, seq_length, dim) for profiling. + + Args: + dim (int): hidden size + batch_size (int): the number of data samples + seq_length (int): the number of tokens + mode (ParallelMode): Colossal-AI ParallelMode enum + + Returns: + data (torch.Tensor): random data + """ + + if mode in ['2d', '2.5d']: + batch_size = batch_size // 2 + dim = dim // 2 + elif mode == '3d': + batch_size = batch_size // 4 + dim = dim // 2 + + data = torch.rand(batch_size, seq_length, dim).cuda() + return data diff --git a/colossalai/cli/check/__init__.py b/colossalai/cli/check/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a86b32bb6a181a7c75bfa6682c2f7514169d9a7f --- /dev/null +++ b/colossalai/cli/check/__init__.py @@ -0,0 +1,13 @@ +import click +from .check_installation import check_installation + +__all__ = ['check'] + + +@click.command(help="Check if Colossal-AI is correct based on the given option") +@click.option('-i', '--installation', is_flag=True, help="Check if Colossal-AI is built correctly") +def check(installation): + if installation: + check_installation() + return + click.echo("No option is given") diff --git a/colossalai/cli/check/check_installation.py b/colossalai/cli/check/check_installation.py new file mode 100644 index 0000000000000000000000000000000000000000..cb3dbbc09301012aa662264f241e1fce89470d39 --- /dev/null +++ b/colossalai/cli/check/check_installation.py @@ -0,0 +1,214 @@ +import subprocess + +import click +import torch +from torch.utils.cpp_extension import CUDA_HOME + +import colossalai + + +def to_click_output(val): + # installation check output to understandable symbols for readability + VAL_TO_SYMBOL = {True: u'\u2713', False: 'x', None: 'N/A'} + + if val in VAL_TO_SYMBOL: + return VAL_TO_SYMBOL[val] + else: + return val + + +def check_installation(): + """ + This function will check the installation of colossalai, specifically, the version compatibility of + colossalai, pytorch and cuda. + + Example: + ```text + ``` + + Returns: A table of installation information. + """ + found_aot_cuda_ext = _check_aot_built_cuda_extension_installed() + cuda_version = _check_cuda_version() + torch_version, torch_cuda_version = _check_torch_version() + colossalai_verison, prebuilt_torch_version_required, prebuilt_cuda_version_required = _parse_colossalai_version() + + # if cuda_version is None, that means either + # CUDA_HOME is not found, thus cannot compare the version compatibility + if not cuda_version: + sys_torch_cuda_compatibility = None + else: + sys_torch_cuda_compatibility = _is_compatible([cuda_version, torch_cuda_version]) + + # if cuda_version or cuda_version_required is None, that means either + # CUDA_HOME is not found or AOT compilation is not enabled + # thus, there is no need to compare the version compatibility at all + if not cuda_version or not prebuilt_cuda_version_required: + sys_colossalai_cuda_compatibility = None + else: + sys_colossalai_cuda_compatibility = _is_compatible([cuda_version, prebuilt_cuda_version_required]) + + # if torch_version_required is None, that means AOT compilation is not enabled + # thus there is no need to compare the versions + if prebuilt_torch_version_required is None: + torch_compatibility = None + else: + torch_compatibility = _is_compatible([torch_version, prebuilt_torch_version_required]) + + click.echo(f'#### Installation Report ####') + click.echo(f'\n------------ Environment ------------') + click.echo(f"Colossal-AI version: {to_click_output(colossalai_verison)}") + click.echo(f"PyTorch version: {to_click_output(torch_version)}") + click.echo(f"System CUDA version: {to_click_output(cuda_version)}") + click.echo(f"CUDA version required by PyTorch: {to_click_output(torch_cuda_version)}") + click.echo("") + click.echo(f"Note:") + click.echo(f"1. The table above checks the versions of the libraries/tools in the current environment") + click.echo(f"2. If the System CUDA version is N/A, you can set the CUDA_HOME environment variable to locate it") + click.echo( + f"3. If the CUDA version required by PyTorch is N/A, you probably did not install a CUDA-compatible PyTorch. This value is give by torch.version.cuda and you can go to https://pytorch.org/get-started/locally/ to download the correct version." + ) + + click.echo(f'\n------------ CUDA Extensions AOT Compilation ------------') + click.echo(f"Found AOT CUDA Extension: {to_click_output(found_aot_cuda_ext)}") + click.echo(f"PyTorch version used for AOT compilation: {to_click_output(prebuilt_torch_version_required)}") + click.echo(f"CUDA version used for AOT compilation: {to_click_output(prebuilt_cuda_version_required)}") + click.echo("") + click.echo(f"Note:") + click.echo( + f"1. AOT (ahead-of-time) compilation of the CUDA kernels occurs during installation when the environment variable CUDA_EXT=1 is set" + ) + click.echo(f"2. If AOT compilation is not enabled, stay calm as the CUDA kernels can still be built during runtime") + + click.echo(f"\n------------ Compatibility ------------") + click.echo(f'PyTorch version match: {to_click_output(torch_compatibility)}') + click.echo(f"System and PyTorch CUDA version match: {to_click_output(sys_torch_cuda_compatibility)}") + click.echo(f"System and Colossal-AI CUDA version match: {to_click_output(sys_colossalai_cuda_compatibility)}") + click.echo(f"") + click.echo(f"Note:") + click.echo(f"1. The table above checks the version compatibility of the libraries/tools in the current environment") + click.echo( + f" - PyTorch version mismatch: whether the PyTorch version in the current environment is compatible with the PyTorch version used for AOT compilation" + ) + click.echo( + f" - System and PyTorch CUDA version match: whether the CUDA version in the current environment is compatible with the CUDA version required by PyTorch" + ) + click.echo( + f" - System and Colossal-AI CUDA version match: whether the CUDA version in the current environment is compatible with the CUDA version used for AOT compilation" + ) + + +def _is_compatible(versions): + """ + Compare the list of versions and return whether they are compatible. + """ + if None in versions: + return False + + # split version into [major, minor, patch] + versions = [version.split('.') for version in versions] + + for version in versions: + if len(version) == 2: + # x means unknown + version.append('x') + + for idx, version_values in enumerate(zip(*versions)): + equal = len(set(version_values)) == 1 + + if idx in [0, 1] and not equal: + return False + elif idx == 1: + return True + else: + continue + + +def _parse_colossalai_version(): + """ + Get the Colossal-AI version information. + + Returns: + colossalai_version: Colossal-AI version. + torch_version_for_aot_build: PyTorch version used for AOT compilation of CUDA kernels. + cuda_version_for_aot_build: CUDA version used for AOT compilation of CUDA kernels. + """ + # colossalai version can be in two formats + # 1. X.X.X+torchX.XXcuXX.X (when colossalai is installed with CUDA extensions) + # 2. X.X.X (when colossalai is not installed with CUDA extensions) + # where X represents an integer. + colossalai_verison = colossalai.__version__.split('+')[0] + + try: + torch_version_for_aot_build = colossalai.__version__.split('torch')[1].split('cu')[0] + cuda_version_for_aot_build = colossalai.__version__.split('cu')[1] + except: + torch_version_for_aot_build = None + cuda_version_for_aot_build = None + return colossalai_verison, torch_version_for_aot_build, cuda_version_for_aot_build + + +def _check_aot_built_cuda_extension_installed(): + """ + According to `op_builder/README.md`, the CUDA extension can be built with either + AOT (ahead-of-time) or JIT (just-in-time) compilation. + AOT compilation will build CUDA extensions to `colossalai._C` during installation. + JIT (just-in-time) compilation will build CUDA extensions to `~/.cache/colossalai/torch_extensions` during runtime. + """ + try: + import colossalai._C.fused_optim + found_aot_cuda_ext = True + except ImportError: + found_aot_cuda_ext = False + return found_aot_cuda_ext + + +def _check_torch_version(): + """ + Get the PyTorch version information. + + Returns: + torch_version: PyTorch version. + torch_cuda_version: CUDA version required by PyTorch. + """ + # get torch version + # torch version can be of two formats + # - 1.13.1+cu113 + # - 1.13.1.devxxx + torch_version = torch.__version__.split('+')[0] + torch_version = '.'.join(torch_version.split('.')[:3]) + + # get cuda version in pytorch build + try: + torch_cuda_major = torch.version.cuda.split(".")[0] + torch_cuda_minor = torch.version.cuda.split(".")[1] + torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}' + except: + torch_cuda_version = None + + return torch_version, torch_cuda_version + + +def _check_cuda_version(): + """ + Get the CUDA version information. + + Returns: + cuda_version: CUDA version found on the system. + """ + + # get cuda version + if CUDA_HOME is None: + cuda_version = CUDA_HOME + else: + try: + raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + cuda_version = f'{bare_metal_major}.{bare_metal_minor}' + except: + cuda_version = None + return cuda_version diff --git a/colossalai/cli/cli.py b/colossalai/cli/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..a94e1150e49fc00a210c20b32a9bfc85eda66aa6 --- /dev/null +++ b/colossalai/cli/cli.py @@ -0,0 +1,25 @@ +import click + +from .benchmark import benchmark +from .check import check +from .launcher import run + + +class Arguments(): + + def __init__(self, arg_dict): + for k, v in arg_dict.items(): + self.__dict__[k] = v + + +@click.group() +def cli(): + pass + + +cli.add_command(run) +cli.add_command(check) +cli.add_command(benchmark) + +if __name__ == '__main__': + cli() diff --git a/colossalai/cli/launcher/__init__.py b/colossalai/cli/launcher/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d9ec147d401a2e5d055852e661e189985b6db6e --- /dev/null +++ b/colossalai/cli/launcher/__init__.py @@ -0,0 +1,87 @@ +import click + +from colossalai.context import Config + +from .run import launch_multi_processes + + +@click.command(help="Launch distributed training on a single node or multiple nodes", + context_settings=dict(ignore_unknown_options=True)) +@click.option("-H", + "-host", + "--host", + type=str, + default=None, + help="the list of hostnames to launch in the format ,") +@click.option( + "--hostfile", + type=str, + default=None, + help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname") +@click.option("--include", + type=str, + default=None, + help="Specify computing devices to use during execution. String format is ,," + " only effective when used with --hostfile.") +@click.option( + "--exclude", + type=str, + default=None, + help= + "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --includ," + " only effective when used with --hostfile.") +@click.option("--num_nodes", + type=int, + default=-1, + help="Total number of worker nodes to use, only effective when used with --hostfile.") +@click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.") +@click.option("--master_port", + type=int, + default=29500, + help="(optional) Port used by PyTorch distributed for communication during distributed training.") +@click.option("--master_addr", + type=str, + default="127.0.0.1", + help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.") +@click.option( + "--extra_launch_args", + type=str, + default=None, + help= + "Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. " + "This will be converted to --arg1=1 --arg2=2 during execution") +@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection") +@click.argument("user_script", type=str) +@click.argument('user_args', nargs=-1) +def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str, + master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None: + """ + To launch multiple processes on a single node or multiple nodes via command line. + + Usage:: + # run with 4 GPUs on the current node use default port 29500 + colossalai run --nprocs_per_node 4 train.py + + # run with 2 GPUs on the current node at port 29550 + colossalai run --nprocs_per_node 4 --master_port 29550 train.py + + # run on two nodes + colossalai run --host , --master_addr host1 --nprocs_per_node 4 train.py + + # run with hostfile + colossalai run --hostfile --master_addr --nprocs_per_node 4 train.py + + # run with hostfile with only included hosts + colossalai run --hostfile --master_addr host1 --include host1,host2 --nprocs_per_node 4 train.py + + # run with hostfile excluding the hosts selected + colossalai run --hostfile --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py + """ + if not user_script.endswith('.py'): + click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help') + exit() + + args_dict = locals() + args = Config(args_dict) + args.user_args = list(args.user_args) + launch_multi_processes(args) diff --git a/colossalai/cli/launcher/hostinfo.py b/colossalai/cli/launcher/hostinfo.py new file mode 100644 index 0000000000000000000000000000000000000000..065cbc37101f9705319d96120779edcbbbf6dde9 --- /dev/null +++ b/colossalai/cli/launcher/hostinfo.py @@ -0,0 +1,127 @@ +import socket +from typing import List + + +class HostInfo: + """ + A data class to store host connection-related data. + + Args: + hostname (str): name or IP address of the host + port (str): the port for ssh connection + """ + + def __init__( + self, + hostname: str, + port: str = None, + ): + self.hostname = hostname + self.port = port + self.is_local_host = HostInfo.is_host_localhost(hostname, port) + + @staticmethod + def is_host_localhost(hostname: str, port: str = None) -> None: + """ + Check if the host refers to the local machine. + + Args: + hostname (str): name or IP address of the host + port (str): the port for ssh connection + + Returns: + bool: True if it is local, False otherwise + """ + + if port is None: + port = 22 # no port specified, lets just use the ssh port + + # socket.getfqdn("127.0.0.1") does not return localhost + # on some users' machines + # thus, we directly return True if hostname is locahost, 127.0.0.1 or 0.0.0.0 + if hostname in ("localhost", "127.0.0.1", "0.0.0.0"): + return True + + hostname = socket.getfqdn(hostname) + localhost = socket.gethostname() + localaddrs = socket.getaddrinfo(localhost, port) + targetaddrs = socket.getaddrinfo(hostname, port) + for (family, socktype, proto, canonname, sockaddr) in localaddrs: + for (rfamily, rsocktype, rproto, rcanonname, rsockaddr) in targetaddrs: + if rsockaddr[0] == sockaddr[0]: + return True + return False + + def __str__(self): + return f'hostname: {self.hostname}, port: {self.port}' + + def __repr__(self): + return self.__str__() + + +class HostInfoList: + """ + A data class to store a list of HostInfo objects. + """ + + def __init__(self): + self.hostinfo_list = [] + + def append(self, hostinfo: HostInfo) -> None: + """ + Add an HostInfo object to the list. + + Args: + hostinfo (HostInfo): host information + """ + + self.hostinfo_list.append(hostinfo) + + def remove(self, hostname: str) -> None: + """ + Add an HostInfo object to the list. + + Args: + hostname (str): the name of the host + """ + + hostinfo = self.get_hostinfo(hostname) + self.hostinfo_list.remove(hostinfo) + + def get_hostinfo(self, hostname: str) -> HostInfo: + """ + Return the HostInfo object which matches with the hostname. + + Args: + hostname (str): the name of the host + + Returns: + hostinfo (HostInfo): the HostInfo object which matches with the hostname + """ + + for hostinfo in self.hostinfo_list: + if hostinfo.hostname == hostname: + return hostinfo + + raise Exception(f"Hostname {hostname} is not found") + + def has(self, hostname: str) -> bool: + """ + Check if the hostname has been added. + + Args: + hostname (str): the name of the host + + Returns: + bool: True if added, False otherwise + """ + for hostinfo in self.hostinfo_list: + if hostinfo.hostname == hostname: + return True + return False + + def __iter__(self): + return iter(self.hostinfo_list) + + def __len__(self): + return len(self.hostinfo_list) diff --git a/colossalai/cli/launcher/multinode_runner.py b/colossalai/cli/launcher/multinode_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..a51e1e371f13df492c16ac5b554d02ba6491d65b --- /dev/null +++ b/colossalai/cli/launcher/multinode_runner.py @@ -0,0 +1,123 @@ +from multiprocessing import Pipe, Process +from multiprocessing import connection as mp_connection + +import click +import fabric + +from .hostinfo import HostInfo, HostInfoList + + +def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection, + send_conn: mp_connection.Connection, env: dict) -> None: + """ + Use fabric connection to execute command on local or remote hosts. + + Args: + hostinfo (HostInfo): host information + workdir (str): the directory to execute the command + recv_conn (multiprocessing.connection.Connection): receive messages from the master sender + send_conn (multiprocessing.connection.Connection): send messages to the master receiver + env (dict): a dictionary for environment variables + """ + + fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port) + finish = False + env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()]) + + # keep listening until exit + while not finish: + # receive cmd + cmds = recv_conn.recv() + + if cmds == 'exit': + # exit from the loop + finish = True + break + else: + # execute the commands + try: + # cd to execute directory + with fab_conn.cd(workdir): + # propagate the runtime environment + with fab_conn.prefix(f"export {env_msg}"): + if hostinfo.is_local_host: + # execute on the local machine + fab_conn.local(cmds, hide=False) + else: + # execute on the remote machine + fab_conn.run(cmds, hide=False) + send_conn.send('success') + except Exception as e: + click.echo( + f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}" + ) + send_conn.send('failure') + + # shutdown + send_conn.send("finish") + fab_conn.close() + + +class MultiNodeRunner: + """ + A runner to execute commands on an array of machines. This runner + is inspired by Nezha (https://github.com/zhuzilin/NeZha). + """ + + def __init__(self): + self.processes = {} + self.master_send_conns = {} + self.master_recv_conns = {} + + def connect(self, host_info_list: HostInfoList, workdir: str, env: dict) -> None: + """ + Establish connections to a list of hosts + + Args: + host_info_list (HostInfoList): a list of HostInfo objects + workdir (str): the directory where command is executed + env (dict): environment variables to propagate to hosts + """ + for hostinfo in host_info_list: + master_send_conn, worker_recv_conn = Pipe() + master_recv_conn, worker_send_conn = Pipe() + p = Process(target=run_on_host, args=(hostinfo, workdir, worker_recv_conn, worker_send_conn, env)) + p.start() + self.processes[hostinfo.hostname] = p + self.master_recv_conns[hostinfo.hostname] = master_recv_conn + self.master_send_conns[hostinfo.hostname] = master_send_conn + + def send(self, hostinfo: HostInfo, cmd: str) -> None: + """ + Send a command to a local/remote host. + + Args: + hostinfo (HostInfo): host information + cmd (str): the command to execute + """ + + assert hostinfo.hostname in self.master_send_conns, \ + f'{hostinfo} is not found in the current connections' + conn = self.master_send_conns[hostinfo.hostname] + conn.send(cmd) + + def stop_all(self) -> None: + """ + Stop connections to all hosts. + """ + + for hostname, conn in self.master_send_conns.items(): + conn.send('exit') + + def recv_from_all(self) -> dict: + """ + Receive messages from all hosts + + Returns: + msg_from_node (dict): a dictionry which contains messages from each node + """ + + msg_from_node = dict() + for hostname, conn in self.master_recv_conns.items(): + msg_from_node[hostname] = conn.recv() + return msg_from_node diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py new file mode 100644 index 0000000000000000000000000000000000000000..6411b4302e95d25efc5a55da1b523b76ee6ee1e3 --- /dev/null +++ b/colossalai/cli/launcher/run.py @@ -0,0 +1,311 @@ +import os +import sys +from typing import List + +import click +import torch +from packaging import version + +from colossalai.context import Config + +from .hostinfo import HostInfo, HostInfoList +from .multinode_runner import MultiNodeRunner + +# Constants that define our syntax +NODE_SEP = ',' + + +def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: + """ + Parse the hostfile to obtain a list of hosts. + + A hostfile should look like: + worker-0 + worker-1 + worker-2 + ... + + Args: + hostfile_path (str): the path to the hostfile + ssh_port (int): the port to connect to the host + """ + + if not os.path.isfile(hostfile_path): + click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}") + exit() + + with open(hostfile_path, 'r') as fd: + device_pool = HostInfoList() + + for line in fd.readlines(): + line = line.strip() + if line == '': + # skip empty lines + continue + + # build the HostInfo object + hostname = line.strip() + hostinfo = HostInfo(hostname=hostname, port=ssh_port) + + if device_pool.has(hostname): + click.echo(f"Error: found duplicate host {hostname} in the hostfile") + exit() + + device_pool.append(hostinfo) + return device_pool + + +def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList: + '''Parse an inclusion or exclusion string and filter a hostfile dictionary. + + Examples: + include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1. + exclude_str="worker-1" will use all available devices except worker-1. + + Args: + device_pool (HostInfoList): a list of HostInfo objects + include_str (str): --include option passed by user, default None + exclude_str (str): --exclude option passed by user, default None + + Returns: + filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion + ''' + + # Ensure include/exclude are mutually exclusive + if include_str and exclude_str: + click.echo("--include and --exclude are mutually exclusive, only one can be used") + exit() + + # no-op + if include_str is None and exclude_str is None: + return device_pool + + # Either build from scratch or remove items + if include_str: + parse_str = include_str + filtered_hosts = HostInfoList() + elif exclude_str: + parse_str = exclude_str + filtered_hosts = device_pool + + # foreach node in the list + for node_config in parse_str.split(NODE_SEP): + hostname = node_config + hostinfo = device_pool.get_hostinfo(hostname) + # sanity check hostname + if not device_pool.has(hostname): + click.echo(f"Error: Hostname '{hostname}' not found in hostfile") + exit() + + if include_str: + filtered_hosts.append(hostinfo) + elif exclude_str: + filtered_hosts.remove(hostname) + + return filtered_hosts + + +def get_launch_command( + master_addr: str, + master_port: int, + nproc_per_node: int, + user_script: str, + user_args: List[str], + node_rank: int, + num_nodes: int, + extra_launch_args: str = None, +) -> str: + """ + Generate a command for distributed training. + + Args: + master_addr (str): the host of the master node + master_port (str): the port of the master node + nproc_per_node (str): the number of processes to launch on each node + user_script (str): the user Python file + user_args (str): the arguments for the user script + node_rank (int): the unique ID for the node + num_nodes (int): the number of nodes to execute jobs + + Returns: + cmd (str): the command the start distributed training + """ + + def _arg_dict_to_list(arg_dict): + ret = [] + + for k, v in arg_dict.items(): + if v: + ret.append(f'--{k}={v}') + else: + ret.append(f'--{k}') + return ret + + if extra_launch_args: + extra_launch_args_dict = dict() + for arg in extra_launch_args.split(','): + if '=' in arg: + k, v = arg.split('=') + extra_launch_args_dict[k] = v + else: + extra_launch_args_dict[arg] = None + extra_launch_args = extra_launch_args_dict + else: + extra_launch_args = dict() + + torch_version = version.parse(torch.__version__) + assert torch_version.major == 1 + + if torch_version.minor < 9: + cmd = [ + sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}", + f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}", + f"--node_rank={node_rank}" + ] + else: + # extra launch args for torch distributed launcher with torch >= 1.9 + default_torchrun_rdzv_args = dict(rdzv_backend="c10d", + rdzv_endpoint=f"{master_addr}:{master_port}", + rdzv_id="colossalai-default-job") + + # update rdzv arguments + for key in default_torchrun_rdzv_args.keys(): + if key in extra_launch_args: + value = extra_launch_args.pop(key) + default_torchrun_rdzv_args[key] = value + + if torch_version.minor < 10: + cmd = [ + sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}", + f"--nnodes={num_nodes}", f"--node_rank={node_rank}" + ] + else: + cmd = [ + "torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}" + ] + cmd += _arg_dict_to_list(default_torchrun_rdzv_args) + + cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args + cmd = ' '.join(cmd) + return cmd + + +def launch_multi_processes(args: Config) -> None: + """ + Launch multiple processes on a single node or multiple nodes. + + The overall logic can be summarized as the pseudo code below: + + if hostfile given: + hostinfo = parse_hostfile(hostfile) + hostinfo = include_or_exclude_hosts(hostinfo) + launch_on_multi_nodes(hostinfo) + elif hosts given: + hostinfo = parse_hosts(hosts) + launch_on_multi_nodes(hostinfo) + else: + launch_on_current_node() + + Args: + args (Config): the arguments taken from command line + + """ + assert isinstance(args, Config) + + if args.nproc_per_node is None: + click.echo("--nproc_per_node did not receive any value") + exit() + + # cannot accept hosts and hostfile at the same time + if args.host and args.hostfile: + click.echo("Error: hostfile and hosts are mutually exclusive, only one is required") + + # check if hostfile is given + if args.hostfile: + device_pool = fetch_hostfile(args.hostfile, ssh_port=args.ssh_port) + active_device_pool = parse_device_filter(device_pool, args.include, args.exclude) + + if args.num_nodes > 0: + # only keep the first num_nodes to execute jobs + updated_active_device_pool = HostInfoList() + for count, hostinfo in enumerate(active_device_pool): + if args.num_nodes == count: + break + updated_active_device_pool.append(hostinfo) + active_device_pool = updated_active_device_pool + else: + active_device_pool = None + + env = os.environ.copy() + + # use hosts if hostfile is not given + if args.host and active_device_pool is None: + active_device_pool = HostInfoList() + host_list = args.host.strip().split(NODE_SEP) + for hostname in host_list: + hostinfo = HostInfo(hostname=hostname, port=args.ssh_port) + active_device_pool.append(hostinfo) + + if not active_device_pool: + # run on local node if not hosts or hostfile is given + # add local node to host info list + active_device_pool = HostInfoList() + localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port) + active_device_pool.append(localhost_info) + + # launch distributed processes + runner = MultiNodeRunner() + curr_path = os.path.abspath('.') + + # collect current path env + env = dict() + for k, v in os.environ.items(): + # do not support multi-line env var + if v and '\n' not in v: + env[k] = v + + # establish remote connection + runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env) + + # execute distributed launching command + for node_id, hostinfo in enumerate(active_device_pool): + cmd = get_launch_command(master_addr=args.master_addr, + master_port=args.master_port, + nproc_per_node=args.nproc_per_node, + user_script=args.user_script, + user_args=args.user_args, + node_rank=node_id, + num_nodes=len(active_device_pool), + extra_launch_args=args.extra_launch_args) + runner.send(hostinfo=hostinfo, cmd=cmd) + + # start training + msg_from_node = runner.recv_from_all() + has_error = False + + # print node status + click.echo("\n====== Training on All Nodes =====") + for hostname, msg in msg_from_node.items(): + click.echo(f"{hostname}: {msg}") + + # check if a process failed + if msg == "failure": + has_error = True + + # stop all nodes + runner.stop_all() + + # receive the stop status + msg_from_node = runner.recv_from_all() + + # printe node status + click.echo("\n====== Stopping All Nodes =====") + for hostname, msg in msg_from_node.items(): + click.echo(f"{hostname}: {msg}") + + # give the process an exit code + # so that it behaves like a normal process + if has_error: + sys.exit(1) + else: + sys.exit(0) diff --git a/colossalai/cluster/__init__.py b/colossalai/cluster/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2fbdfd3cc9996b1044720ef0c1669f5f67fbe8b3 --- /dev/null +++ b/colossalai/cluster/__init__.py @@ -0,0 +1,5 @@ +from .device_mesh_manager import DeviceMeshManager +from .dist_coordinator import DistCoordinator +from .process_group_manager import ProcessGroupManager + +__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager'] diff --git a/colossalai/cluster/device_mesh_manager.py b/colossalai/cluster/device_mesh_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..8754baa19792adf6c5ec79c115d36fac9a3f3c5d --- /dev/null +++ b/colossalai/cluster/device_mesh_manager.py @@ -0,0 +1,117 @@ +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union + +import torch +import torch.distributed as dist + +from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler +from colossalai.device.device_mesh import DeviceMesh + + +@dataclass +class DeviceMeshInfo: + ''' + This class is used to store the information used to initialize the device mesh. + + Args: + physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7]. + mesh_shapes (List[Union[torch.Size, List[int], Tuple[int]]]): The shape of the mesh. For example, if we have 4 GPUs and we want to use 2D mesh with mesh shape [2, 2], then the mesh shape should be [2, 2]. + ''' + physical_ids: List[int] + mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None + + def __post_init__(self): + if self.mesh_shape is not None: + world_size = len(self.physical_ids) + mesh_shape_numel = torch.Size(self.mesh_shape).numel() + assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}' + + +def initialize_device_mesh(device_mesh_info: DeviceMeshInfo): + ''' + This method is used to initialize the device mesh. + + Args: + device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh. + ''' + # parse the device mesh info + physical_devices = device_mesh_info.physical_ids + physical_mesh = torch.tensor(physical_devices) + logical_mesh_shape = device_mesh_info.mesh_shape + + if logical_mesh_shape is None: + ab_profiler = AlphaBetaProfiler(physical_devices) + # search for the best logical mesh shape + logical_mesh_id = ab_profiler.search_best_logical_mesh() + logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int) + + else: + logical_mesh_id = physical_mesh.reshape(logical_mesh_shape) + + device_mesh = DeviceMesh(physical_mesh_id=physical_mesh, logical_mesh_id=logical_mesh_id, init_process_group=True) + return device_mesh + + +class DeviceMeshManager: + """ + Device mesh manager is responsible for creating and managing device meshes. + """ + + def __init__(self): + self.device_mesh_store: Dict[str, DeviceMesh] = dict() + + def create_device_mesh(self, name, device_mesh_info: DeviceMeshInfo) -> DeviceMesh: + """ + Create a device mesh and store it in the manager. + + Args: + name (str): name of the device mesh + device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh + """ + if name not in self.device_mesh_store: + device_mesh = initialize_device_mesh(device_mesh_info) + self.device_mesh_store[name] = device_mesh + return device_mesh + else: + raise ValueError(f'Device mesh {name} already exists.') + + def get(self, name: str) -> DeviceMesh: + """ + Get a device mesh by name. + + Args: + name (str): name of the device mesh + + Returns: + DeviceMesh: the device mesh + """ + if name in self.device_mesh_store: + return self.device_mesh_store[name] + else: + raise ValueError(f'Device mesh {name} does not exist.') + + def destroy(self, name: str) -> None: + """ + Destroy a device mesh by name. + + Args: + name (str): name of the device mesh + """ + if name in self.device_mesh_store: + for pgs in self.device_mesh_store[name].process_groups_dict.values(): + for pg in pgs: + dist.destroy_process_group(pg) + del self.device_mesh_store[name] + else: + raise ValueError(f'Device mesh {name} does not exist.') + + def destroy_all(self): + """ + Destroy all device meshes. + """ + for name in self.device_mesh_store: + for pgs in self.device_mesh_store[name].process_groups_dict.values(): + for pg in pgs: + dist.destroy_process_group(pg) + + self.device_mesh_store.clear() diff --git a/colossalai/cluster/dist_coordinator.py b/colossalai/cluster/dist_coordinator.py new file mode 100644 index 0000000000000000000000000000000000000000..99dde810e11251e16235573f6dd88e68b74a64b3 --- /dev/null +++ b/colossalai/cluster/dist_coordinator.py @@ -0,0 +1,194 @@ +import functools +import os +from contextlib import contextmanager + +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from colossalai.context.singleton_meta import SingletonMeta + + +class DistCoordinator(metaclass=SingletonMeta): + """ + This class is used to coordinate distributed training. It is a singleton class, which means that there is only one instance of this + class in the whole program. + + There are some terms that are used in this class: + - rank: the rank of the current process + - world size: the total number of processes + - local rank: the rank of the current process on the current node + - master: the process with rank 0 + - node master: the process with local rank 0 on the current node + + Example: + >>> from colossalai.cluster.dist_coordinator import DistCoordinator + >>> coordinator = DistCoordinator() + >>> + >>> if coordinator.is_master(): + >>> do_something() + >>> + >>> coordinator.print_on_master('hello world') + + Attributes: + rank (int): the rank of the current process + world_size (int): the total number of processes + local_rank (int): the rank of the current process on the current node + """ + + def __init__(self): + assert dist.is_initialized( + ), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.' + self._rank = dist.get_rank() + self._world_size = dist.get_world_size() + # this is often passed by launchers such as torchrun + self._local_rank = os.environ.get('LOCAL_RANK', -1) + + @property + def rank(self) -> int: + return self._rank + + @property + def world_size(self) -> int: + return self._world_size + + @property + def local_rank(self) -> int: + return self._local_rank + + def _assert_local_rank_set(self): + """ + Assert that the local rank is set. This is often passed by launchers such as torchrun. + """ + assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.' + + def is_master(self, process_group: ProcessGroup = None) -> bool: + """ + Check if the current process is the master process (rank is 0). It can accept a sub process group to check the rank 0 with respect to the process. + + Args: + process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group. + + Returns: + bool: True if the current process is the master process, False otherwise + """ + rank = dist.get_rank(group=process_group) + return rank == 0 + + def is_node_master(self) -> bool: + """ + Check if the current process is the master process on the current node (local rank is 0). + + Returns: + bool: True if the current process is the master process on the current node, False otherwise + """ + self._assert_local_rank_set() + return self.local_rank == 0 + + def is_last_process(self, process_group: ProcessGroup = None) -> bool: + """ + Check if the current process is the last process (rank is world size - 1). It can accept a sub process group to check the last rank with respect to the process. + + Args: + process_group (ProcessGroup, optional): process group to use for the last rank check. Defaults to None, which refers to the default process group. + + Returns: + bool: True if the current process is the last process, False otherwise + """ + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + return rank == world_size - 1 + + def print_on_master(self, msg: str, process_group: ProcessGroup = None): + """ + Print message only from rank 0. + + Args: + msg (str): message to print + process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group. + """ + rank = dist.get_rank(group=process_group) + if rank == 0: + print(msg) + + def print_on_node_master(self, msg: str): + """ + Print message only from local rank 0. Local rank 0 refers to the 0th process running the current node. + + Args: + msg (str): message to print + """ + self._assert_local_rank_set() + if self.local_rank == 0: + print(msg) + + @contextmanager + def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup = None): + """ + This context manager is used to allow one process to execute while blocking all + other processes in the same process group. This is often useful when downloading is required + as we only want to download in one process to prevent file corruption. + + Example: + >>> from colossalai.cluster import DistCoordinator + >>> dist_coordinator = DistCoordinator() + >>> with dist_coordinator.priority_execution(): + >>> dataset = CIFAR10(root='./data', download=True) + + Args: + executor_rank (int): the process rank to execute without blocking, all other processes will be blocked + process_group (ProcessGroup, optional): process group to use for the executor rank check. Defaults to None, which refers to the default process group. + """ + rank = dist.get_rank(group=process_group) + should_block = rank != executor_rank + + if should_block: + self.block_all(process_group) + + yield + + if not should_block: + self.block_all(process_group) + + def destroy(self, process_group: ProcessGroup = None): + """ + Destroy the distributed process group. + + Args: + process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group. + """ + dist.destroy_process_group(process_group) + + def block_all(self, process_group: ProcessGroup = None): + """ + Block all processes in the process group. + + Args: + process_group (ProcessGroup, optional): process group to block. Defaults to None, which refers to the default process group. + """ + dist.barrier(group=process_group) + + def on_master_only(self, process_group: ProcessGroup = None): + """ + A function wrapper that only executes the wrapped function on the master process (rank 0). + + Example: + >>> from colossalai.cluster import DistCoordinator + >>> dist_coordinator = DistCoordinator() + >>> + >>> @dist_coordinator.on_master_only() + >>> def print_on_master(msg): + >>> print(msg) + """ + is_master = self.is_master(process_group) + + # define an inner functiuon + def decorator(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if is_master: + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/colossalai/cluster/process_group_manager.py b/colossalai/cluster/process_group_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..e52661846f3ed6d25602252886401837613a75e3 --- /dev/null +++ b/colossalai/cluster/process_group_manager.py @@ -0,0 +1,75 @@ +from typing import List + +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +class ProcessGroupManager: + """ + ProcessGroupManager is used to manage the process groups in the cluster. + + There are some terms used in this class: + - pg: the short name for process group + - pg_name: the name of the process group + - pg_size: the world size of the process group + - rank: the rank of the current process in the process group + - world_size: the total number of processes in the process group + """ + + def __init__(self): + self.pg_store = dict() + + def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup: + """ + Get a process group by name. If the process group does not exist, it will be created. + + Args: + name (str): name of the process group + ranks (List[int]): ranks of the process group + backend (str, optional): backend of the process group. Defaults to 'nccl'. + + Returns: + ProcessGroup: the process group + """ + if name not in self.pg_store: + pg = dist.new_group(ranks=ranks, backend=backend) + self.pg_store[name] = pg + return pg + else: + raise ValueError(f'Process group {name} already exists.') + + def get(self, name: str) -> ProcessGroup: + """ + Get a process group by name. + + Args: + name (str): name of the process group + + Returns: + ProcessGroup: the process group + """ + if name in self.pg_store: + return self.pg_store[name] + else: + raise ValueError(f'Process group {name} does not exist.') + + def destroy(self, name: str) -> None: + """ + Destroy a process group by name. + + Args: + name (str): name of the process group + """ + if name in self.pg_store: + dist.destroy_process_group(self.pg_store[name]) + del self.pg_store[name] + else: + raise ValueError(f'Process group {name} does not exist.') + + def destroy_all(self) -> None: + """ + Destroy all process groups. + """ + for name in self.pg_store: + dist.destroy_process_group(self.pg_store[name]) + self.pg_store.clear() diff --git a/colossalai/communication/__init__.py b/colossalai/communication/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..220481b7af15bcada443ed7c9f8c91350a5f76b1 --- /dev/null +++ b/colossalai/communication/__init__.py @@ -0,0 +1,26 @@ +from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce +from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward, + send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward, + recv_forward, recv_backward) +from .ring import ring_forward +from .utils import send_obj_meta, recv_obj_meta + +__all__ = [ + 'all_gather', + 'reduce_scatter', + 'all_reduce', + 'broadcast', + 'reduce', + 'send_forward', + 'send_forward_recv_forward', + 'send_forward_backward_recv_forward_backward', + 'send_backward', + 'send_backward_recv_backward', + 'send_backward_recv_forward', + 'send_forward_recv_backward', + 'recv_backward', + 'recv_forward', + 'ring_forward', + 'send_obj_meta', + 'recv_obj_meta', +] diff --git a/colossalai/communication/collective.py b/colossalai/communication/collective.py new file mode 100644 index 0000000000000000000000000000000000000000..64fb5b8b5296fa8afe7b20c9c96609f7b999e8c0 --- /dev/null +++ b/colossalai/communication/collective.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ReduceOp + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + +_all_gather_func = dist._all_gather_base \ + if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor +_reduce_scatter_func = dist._reduce_scatter_base \ + if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor + + +def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor: + r"""Gathers all tensors from the parallel group and concatenates them in a + specific dimension. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be gathered. + dim (int): The dimension concatenating in. + parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + async_op (bool, optional): Whether operations are asynchronous. + + Returns: + Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The result of all-together only, + if async_op is set to False. A tuple of output of all-gather and Async work handle, if async_op is set to True. + """ + depth = gpc.get_world_size(parallel_mode) + if depth == 1: + out = tensor + work = None + else: + tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous() + out_shape = (tensor_in.shape[0] * depth,) + tensor_in.shape[1:] + tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device) + group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode) + work = _all_gather_func(tensor_out, tensor_in, group=group, async_op=async_op) + out = tensor_out if dim == 0 else tensor_out.transpose(0, dim) + if async_op: + return out, work + else: + return out + + +def reduce_scatter(tensor: Tensor, + dim: int, + parallel_mode: ParallelMode, + op: ReduceOp = ReduceOp.SUM, + async_op: bool = False) -> Tensor: + r"""Reduces all tensors then scatters it in a specific dimension to all + members in the parallel group. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be reduce_scattered. + dim (int): The dimension concatenating in. + parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + op (torch.distributed.ReduceOp, optional): The type of reduce operation, + should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR]. + More details about ReduceOp please refer to + `ReduceOp `_. + async_op (bool, optional): Whether operations are asynchronous. + + Returns: + Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The result of reduce_scatter only, + if async_op is set to False. A tuple of output of all-gather and Async work handle, if async_op is set to True. + """ + depth = gpc.get_world_size(parallel_mode) + if depth == 1: + out = tensor + work = None + else: + tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous() + out_shape = (tensor_in.shape[0] // depth,) + tensor_in.shape[1:] + tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device) + group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode) + work = _reduce_scatter_func(tensor_out, tensor_in, op=op, group=group, async_op=async_op) + out = tensor_out if dim == 0 else tensor_out.transpose(0, dim) + if async_op: + return out, work + else: + return out + + +def all_reduce(tensor: Tensor, + parallel_mode: ParallelMode, + op: ReduceOp = ReduceOp.SUM, + async_op: bool = False) -> Tensor: + r"""Reduces the tensor data across whole parallel group in such a way that all get the final result. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be all-reduced. + parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + op (torch.distributed.ReduceOp, optional): The type of reduce operation, + should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR]. + More details about ReduceOp please refer to + `ReduceOp `_. + async_op (bool, optional): Whether operations are asynchronous. + + Returns: + Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The result of all-gather only, + if async_op is set to False. A tuple of output of all-gather and Async work handle, if async_op is set to True. + """ + depth = gpc.get_world_size(parallel_mode) + if depth == 1: + out = tensor + work = None + else: + out = tensor.contiguous() + group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode) + work = dist.all_reduce(out, op=op, group=group, async_op=async_op) + if async_op: + return out, work + else: + return out + + +def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False): + r"""Broadcast tensors to whole parallel group. Tensor must have the same + number of elements in all processes participating in the collective. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be broadcast. + src (int): Source rank. + parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + async_op (bool, optional): Whether operations are asynchronous. + + Returns: + Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The tensor need to be broadcast only, + if async_op is set to False. A tuple of output of all-gather and Async work handle, if async_op is set to True. + """ + depth = gpc.get_world_size(parallel_mode) + if depth == 1: + out = tensor + work = None + else: + out = tensor.contiguous() + group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode) + work = dist.broadcast(out, src=src, group=group, async_op=async_op) + if async_op: + return out, work + else: + return out + + +def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False): + r"""Reduce tensors across whole parallel group. Only the process with + rank ``dst`` is going to receive the final result. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be reduced. + dst (int): Destination rank. + parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + async_op (bool, optional): Whether operations are asynchronous. + + Returns: + Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The result of reduce only, + if async_op is set to False. A tuple of output of all-gather and Async work handle, if async_op is set to True. + """ + depth = gpc.get_world_size(parallel_mode) + if depth == 1: + out = tensor + work = None + else: + out = tensor.contiguous() + group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode) + work = dist.reduce(out, dst=dst, op=op, group=group, async_op=async_op) + if async_op: + return out, work + else: + return out + + +def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None: + r"""Modified from `torch.distributed.scatter_object_list + ` to fix issues + """ + if dist.distributed_c10d._rank_not_in_group(group): + return + + if (not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1): + raise RuntimeError("Expected argument scatter_object_output_list to be a list of size at least 1.") + + # set tensor device to cuda if backend is nccl + device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu") + + my_rank = dist.get_rank() # use global rank + if my_rank == src: + tensor_list, tensor_sizes = zip( + *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list]) + tensor_list = list(map(lambda x: x.to(device), tensor_list)) + tensor_sizes = list(map(lambda x: x.to(device), tensor_sizes)) + + # Src rank broadcasts the maximum tensor size. This is because all ranks are + # expected to call into scatter() with equal-sized tensors. + if my_rank == src: + max_tensor_size = max(tensor_sizes) + for tensor in tensor_list: + tensor.resize_(max_tensor_size) + else: + max_tensor_size = torch.tensor([0], dtype=torch.long).to(device) + + dist.broadcast(max_tensor_size, src=src, group=group) + + # Scatter actual serialized objects + output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8).to(device) + dist.scatter( + output_tensor, + scatter_list=None if my_rank != src else tensor_list, + src=src, + group=group, + ) + + # Scatter per-object sizes to trim tensors when deserializing back to object + obj_tensor_size = torch.tensor([0], dtype=torch.long).to(device) + dist.scatter( + obj_tensor_size, + scatter_list=None if my_rank != src else tensor_sizes, + src=src, + group=group, + ) + + output_tensor, obj_tensor_size = output_tensor.cpu(), obj_tensor_size.cpu() + # Deserialize back to object + scatter_object_output_list[0] = dist.distributed_c10d._tensor_to_object(output_tensor, obj_tensor_size) diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py new file mode 100644 index 0000000000000000000000000000000000000000..0200cd3c6553dc8e2b3bbaa60ffb1d416c699370 --- /dev/null +++ b/colossalai/communication/p2p.py @@ -0,0 +1,405 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import List, Tuple, Union +import torch +import torch.distributed as dist + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device +from functools import reduce +import operator +from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor + +TensorShape = Union[torch.Size, List[int], Tuple[int]] + + +def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]: + """get the exact tensor shape when communicating and return whether the tensor is a chunk + + Args: + tensor_shape (:class:`torch.Size`): shape of tensor + chunk_tensor (bool, optional): whether to chunk tensor, defaults to False + + Returns: + Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor + """ + if chunk_tensor: + tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) + tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR) + if tensor_chunk_shape % tensor_parallel_world_size == 0: + tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size + else: + tensor_chunk_shape = tensor_shape + chunk_tensor = False + else: + tensor_chunk_shape = tensor_shape + return tensor_chunk_shape, chunk_tensor + + +def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors): + if isinstance(recv_shapes, torch.Size): + recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors) + buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) + return buffer_recv, recv_split + buffer_recv = [] + for recv_shape in recv_shapes: + recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors) + tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) + buffer_recv.append(tensor_recv) + return buffer_recv, recv_split + + +def process_object_to_send(object_send, scatter_gather_tensors): + if isinstance(object_send, torch.Tensor): + send_split = _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1] + if send_split: + object_send = split_tensor_into_1d_equal_chunks(object_send) + return object_send + + object_send_list = [] + for tensor_send in object_send: + send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1] + if send_split: + object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send)) + else: + object_send_list.append(tensor_send) + object_send = tuple(object_send_list) + + return object_send + + +def filling_ops_queue(obj, comm_op, comm_rank, ops_queue): + if isinstance(obj, torch.Tensor): + op_to_add = dist.P2POp(comm_op, obj, comm_rank) + ops_queue.append(op_to_add) + else: + for tensor_to_comm in obj: + op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank) + ops_queue.append(op_to_add) + + +def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None, + object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None, + recv_prev: bool = False, + recv_next: bool = False, + recv_prev_shape: Union[torch.Size, List[torch.Size]] = None, + recv_next_shape: Union[torch.Size, List[torch.Size]] = None, + prev_rank: int = None, + next_rank: int = None, + dtype: torch.dtype = None, + scatter_gather_tensors: bool = False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: + """ + Adapted from megatron.p2p_communication. + Communicate tensors between stages. Used as helper method in other + communication methods that are used in pipeline schedule. + Takes the following arguments: + object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank (no tensor sent if + set to None). + object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank (no tensor sent if + set to None). + recv_prev (bool): boolean for whether tensor should be received from + previous rank. + recv_next (bool): boolean for whether tensor should be received from + next rank. + recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the previous stage, defaults to None. + recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the next stage, defaults to None. + prev_rank (int): the rank of the previous pipeline stage, defaults to None, + next_rank (int): the rank of the next pipeline stage, defaults to None, + dtype (torch.dtype): data type of intermediate buffers, defaults to None + scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False + + Returns: + Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next + """ + + # Create placeholder tensors for receive in forward and backward directions + # if needed. + tensor_recv_prev = None + tensor_recv_next = None + + if recv_prev: + assert recv_prev_shape is not None + tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(recv_prev_shape, dtype, + scatter_gather_tensors) + + if recv_next: + assert recv_next_shape is not None + tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(recv_next_shape, dtype, + scatter_gather_tensors) + + if object_send_prev is not None or recv_prev: + if prev_rank is None: + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + + if object_send_next is not None or recv_next: + if next_rank is None: + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + + if object_send_prev is not None: + object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors) + + if object_send_next is not None: + object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors) + + ops = [] + if object_send_prev is not None: + filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops) + + if tensor_recv_prev is not None: + filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops) + + if tensor_recv_next is not None: + filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops) + + if object_send_next is not None: + filling_ops_queue(object_send_next, dist.isend, next_rank, ops) + + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + torch.cuda.synchronize() + + if recv_prev and recv_prev_split: + if isinstance(tensor_recv_prev, torch.Tensor): + tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_() + else: + for index in range(len(tensor_recv_prev)): + tensor_recv_prev[index] = gather_split_1d_tensor(tensor_recv_prev[index]).view( + recv_prev_shape[index]).requires_grad_() + + if recv_next and recv_next_split: + if isinstance(tensor_recv_next, torch.Tensor): + tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_() + else: + for index in range(len(tensor_recv_next)): + tensor_recv_next[index] = gather_split_1d_tensor(tensor_recv_next[index]).view( + recv_next_shape[index]).requires_grad_() + + return tensor_recv_prev, tensor_recv_next + + +def recv_forward(input_tensor_shape, + prev_rank=None, + dtype=torch.float, + scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + + Args: + input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list. + """ + if gpc.is_pipeline_first_stage(): + input_tensor = None + else: + input_tensor, _ = _communicate(recv_prev=True, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors) + return input_tensor + + +def recv_backward(output_grad_shape, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + + Args: + output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradident tensor list. + """ + if gpc.is_pipeline_last_stage(): + output_tensor_grad = None + else: + _, output_tensor_grad = _communicate(recv_next=True, + recv_next_shape=output_grad_shape, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors) + return output_tensor_grad + + +def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None: + """Sends the input tensor to the next stage in pipeline. + + Args: + output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not gpc.is_pipeline_last_stage(): + _communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors) + + +def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + + Args: + input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not gpc.is_pipeline_first_stage(): + _communicate(object_send_prev=input_tensor_grad, + prev_rank=prev_rank, + scatter_gather_tensors=scatter_gather_tensors) + + +def send_forward_recv_backward(output_tensor, + output_grad_shape, + recv_next=True, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: + """Batched communication operation. Sends the input tensor to the + next stage in pipeline, while receives the gradient tensor from the + next stage in pipeline as the input gradient tensor of this stage. + + Args: + output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor. + """ + if gpc.is_pipeline_last_stage(): + output_tensor_grad = None + else: + _, output_tensor_grad = _communicate(object_send_next=output_tensor, + recv_next=recv_next, + recv_next_shape=output_grad_shape, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors) + return output_tensor_grad + + +def send_backward_recv_forward(input_tensor_grad, + input_tensor_shape, + recv_prev=True, + prev_rank=None, + dtype=torch.float, + scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: + """Batched communication operation. Sends the gradient tensor to the + previous stage in pipeline, while receives the output tensor from the + previous stage in pipeline as the input of this stage. + + Args: + input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor. + """ + if gpc.is_pipeline_first_stage(): + input_tensor = None + else: + input_tensor, _ = _communicate(object_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors) + return input_tensor + + +def send_forward_recv_forward(output_tensor, + input_tensor_shape, + recv_prev=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: + """Batched communication operation. Sends the input tensor to the + next stage in pipeline, while receives the output tensor from the + previous stage in pipeline as the input of this stage. + + Args: + output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor. + """ + input_tensor, _ = _communicate(object_send_next=output_tensor, + recv_prev=recv_prev, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors) + return input_tensor + + +def send_backward_recv_backward(input_tensor_grad, + output_grad_shape, + recv_next=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: + """Batched communication operation. Sends the gradient tensor to the + previous stage in pipeline, while receives the gradient tensor from the + next member in pipeline as the input of this stage. + + Args: + input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor. + """ + _, output_tensor_grad = _communicate(object_send_prev=input_tensor_grad, + recv_next=recv_next, + recv_next_shape=output_grad_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors) + return output_tensor_grad + + +def send_forward_backward_recv_forward_backward( + output_tensor, + input_tensor_grad, + input_tensor_shape, + output_grad_shape, + recv_prev=True, + recv_next=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: + """Batched communication operation. Sends the input tensor to the next stage in pipeline and + the gradient tensor to the previous stage, while receives the input gradient tensor from the + next stage and the input tensor from the previous stage. + + Args: + output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the next. + input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the previous. + input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the previous. + output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the next. + + Returns: + Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor) + """ + input_tensor, output_tensor_grad = _communicate(object_send_next=output_tensor, + object_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + recv_prev_shape=input_tensor_shape, + recv_next_shape=output_grad_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors) + return input_tensor, output_tensor_grad diff --git a/colossalai/communication/p2p_v2.py b/colossalai/communication/p2p_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..0dacd8c3c9b5bacf65aee3168bcbd170c5a5b6dc --- /dev/null +++ b/colossalai/communication/p2p_v2.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import io +import pickle +from typing import Any, List, Tuple, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroupNCCL +from torch.distributed import distributed_c10d as c10d + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc + +TensorShape = Union[torch.Size, List[int], Tuple[int]] +_pg_manager = {} +_unpickler = pickle.Unpickler + + +def init_process_group(): + """intialise process group by dist.new_group in the adjacent stages + + Args: + None + + Returns: + None + """ + world_size = gpc.get_world_size(ParallelMode.PIPELINE) + for i in range(world_size - 1): + _pg_manager[(i, i + 1)] = dist.new_group([i, i + 1]) + + +def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGroupNCCL: + """get the group handle of two given ranks + + Args: + first_rank (int): first rank in the pair + second_rank (int): second rank in the pair + + Returns: + :class:`ProcessGroupNCCL`: the handle of the group consisting of the given two ranks + """ + if len(_pg_manager) == 0: + init_process_group() + if first_rank > second_rank: + first_rank, second_rank = second_rank, first_rank + pair_key = (first_rank, second_rank) + return _pg_manager[pair_key] + + +def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object: + """transform tensor to object with unpickle. + Info of the device in bytes stream will be modified into current device before unpickling + + Args: + tensor (:class:`torch.tensor`): tensor to be unpickled + tensor_size (:class:`torch.Size`): Size of the real info in bytes + + Returns: + Any: object after unpickled + """ + buf = tensor.numpy().tobytes()[:tensor_size] + if b'cuda' in buf: + buf_array = bytearray(buf) + device_index = torch.cuda.current_device() + buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index + buf = bytes(buf_array) + + io_bytes = io.BytesIO(buf) + byte_pickler = _unpickler(io_bytes) + unpickle = byte_pickler.load() + + return unpickle + + +def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=None): + """This is a modified version of the broadcast_object_list in torch.distribution + The only difference is that object will be move to correct device after unpickled. + If local_rank = src, then object list will be sent to rank src. Otherwise, object list will + be updated with data sent from rank src. + + Args: + object_list (List[Any]): list of object to broadcast + src (int): source rank to broadcast + dst (int): dst rank to broadcast + device (:class:`torch.device`): device to do broadcast. current device in default + + """ + group = _acquire_pair_group_handle(src, dst) + + if c10d._rank_not_in_group(group): + c10d._warn_not_in_group("broadcast_object_list") + return + + local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + # Serialize object_list elements to tensors on src rank. + if local_rank == src: + tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) + + is_nccl_backend = c10d._check_for_nccl_backend(group) + current_device = None + + if device is not None: + if is_nccl_backend and device.type != "cuda": + raise ValueError("device type must be cuda for nccl backend") + current_device = device + else: + current_device = torch.device("cpu") + if is_nccl_backend: + current_device = torch.device("cuda", torch.cuda.current_device()) + if is_nccl_backend: + object_sizes_tensor = object_sizes_tensor.to(current_device) + + # Broadcast object sizes + c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False) + + # Concatenate and broadcast serialized object tensors + if local_rank == src: + object_tensor = torch.cat(tensor_list) + else: + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=torch.uint8, + ) + + if is_nccl_backend: + object_tensor = object_tensor.to(current_device) + + c10d.broadcast(object_tensor, src=src, group=group, async_op=False) + + # Deserialize objects using their stored sizes. + offset = 0 + + if local_rank != src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset:offset + obj_size] + obj_view = obj_view.type(torch.uint8) + if obj_view.device != torch.device("cpu"): + obj_view = obj_view.cpu() + offset += obj_size + # unpickle + unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size) + + # unconsistence in device + if isinstance(unpickle_object, + torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device(): + unpickle_object = unpickle_object.cuda() + + object_list[i] = unpickle_object + + +def _send_object(object: Any, dst: int) -> None: + """send anything to dst rank + Args: + object (Any): object needed to be sent + dst (int): rank of the destination + + Returns: + None + """ + local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + # handler = _acquire_pair_group_handle(local_rank, dst) + + # transform to list if not + if isinstance(object, torch.Tensor): + object = [object] + + # broadcast length first + # TODO : more elegant ? P.S. reduce a _broadcast_object_list + _broadcast_object_list([len(object)], local_rank, dst) + # then broadcast safely + _broadcast_object_list(object, local_rank, dst) + + +def _recv_object(src: int) -> Any: + """recv anything from src + + Args: + src (int): source rank of data. local rank will receive data from src rank. + + Returns: + Any: Object received from src. + """ + local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + # handler = _acquire_pair_group_handle(local_rank, src) + # recv length first + length = [0] + _broadcast_object_list(length, src, local_rank) + + # then create recv buff from length[0] and broadcast + object = [None] * length[0] + _broadcast_object_list(object, src, local_rank) + + if length[0] == 1: + object = object[0] + + return object + + +def recv_forward(prev_rank: int = None) -> Any: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + + Args: + input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + """ + if gpc.is_pipeline_first_stage(): + input_tensor = None + else: + if prev_rank is None: + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + input_tensor = _recv_object(prev_rank) + + return input_tensor + + +def recv_backward(next_rank: int = None) -> Any: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + + Args: + output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + """ + if gpc.is_pipeline_last_stage(): + output_tensor_grad = None + else: + if next_rank is None: + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + output_tensor_grad = _recv_object(next_rank) + + return output_tensor_grad + + +def send_forward(output_object: Any, next_rank: int = None) -> None: + """Sends the input tensor to the next stage in pipeline. + + Args: + output_object Any: Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not gpc.is_pipeline_last_stage(): + if next_rank is None: + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + _send_object(output_object, next_rank) + + +def send_backward(input_object: Any, prev_rank: int = None) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + + Args: + input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not gpc.is_pipeline_first_stage(): + if prev_rank is None: + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + _send_object(input_object, prev_rank) diff --git a/colossalai/communication/ring.py b/colossalai/communication/ring.py new file mode 100644 index 0000000000000000000000000000000000000000..aece7574b7c41cac3b16cd5891b1e26d0ede9c36 --- /dev/null +++ b/colossalai/communication/ring.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device, synchronize + + +def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor: + """Sends a tensor to the next member and receives a tensor from the previous member. + This function returns the received tensor from the previous member. + + Args: + tensor_send_next (:class:`torch.Tensor`): Tensor sent to next member + parallel_mode (ParallelMode): Parallel group mode used in this communication + + Returns: + :class:`torch.Tensor`: The tensor received from the previous. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + buffer_shape = tensor_send_next.size() + + ops = [] + current_rank = gpc.get_global_rank() + + tensor_recv_prev = torch.empty(buffer_shape, + requires_grad=True, + device=get_current_device(), + dtype=tensor_send_next.dtype) + + # send to next rank + send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next, + gpc.get_next_global_rank(parallel_mode)) + ops.append(send_next_op) + + # receive from prev rank + recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev, + gpc.get_prev_global_rank(parallel_mode)) + ops.append(recv_prev_op) + + if current_rank % 2 == 0: + ops = ops[::-1] + + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # To protect against race condition when using batch_isend_irecv(). + synchronize() + + return tensor_recv_prev diff --git a/colossalai/communication/utils.py b/colossalai/communication/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ef9eceea847dd3d6cb036e87e369529dcbe0db41 --- /dev/null +++ b/colossalai/communication/utils.py @@ -0,0 +1,126 @@ +import torch +import torch.distributed as dist + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device +from typing import Union, List, Tuple + +TensorShape = Union[torch.Size, List[int], Tuple[int]] + + +def send_meta_helper(obj, next_rank, tensor_kwargs): + send_shape = torch.tensor(obj.size(), **tensor_kwargs) + send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs) + dist.send(send_ndims, next_rank) + dist.send(send_shape, next_rank) + + +def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool: + """Sends obj meta information before sending a specific obj. + Since the recipient must know the shape of the obj in p2p communications, + meta information of the obj should be sent before communications. This function + synchronizes with :func:`recv_obj_meta`. + + Args: + obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent. + need_meta (bool, optional): If False, meta information won't be sent. + next_rank (int): The rank of the next member in pipeline parallel group. + + Returns: + bool: False + """ + if need_meta: + if next_rank is None: + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + + tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} + if isinstance(obj, torch.Tensor): + send_obj_nums = torch.tensor(1, **tensor_kwargs) + dist.send(send_obj_nums, next_rank) + send_meta_helper(obj, next_rank, tensor_kwargs) + else: + send_obj_nums = torch.tensor(len(obj), **tensor_kwargs) + dist.send(send_obj_nums, next_rank) + for tensor_to_send in obj: + send_meta_helper(tensor_to_send, next_rank, tensor_kwargs) + + return False + + +def recv_meta_helper(prev_rank, tensor_kwargs): + recv_ndims = torch.empty((), **tensor_kwargs) + dist.recv(recv_ndims, prev_rank) + recv_shape = torch.empty(recv_ndims, **tensor_kwargs) + dist.recv(recv_shape, prev_rank) + return recv_shape + + +def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size: + """Receives obj meta information before receiving a specific obj. + Since the recipient must know the shape of the obj in p2p communications, + meta information of the obj should be received before communications. This function + synchronizes with :func:`send_obj_meta`. + + Args: + obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received. + prev_rank (int): The rank of the source of the obj. + + Returns: + Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received. + """ + if obj_shape is None: + if prev_rank is None: + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + + tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} + recv_obj_nums = torch.empty((), **tensor_kwargs) + dist.recv(recv_obj_nums, prev_rank) + if recv_obj_nums.item() == 1: + recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) + obj_shape = torch.Size(recv_shape) + else: + obj_shape = [] + for i in range(recv_obj_nums.item()): + recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) + obj_shape.append(torch.Size(recv_shape)) + + return obj_shape + + +def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor: + """Break a tensor into equal 1D chunks. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be split before communication. + new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor. + + Returns: + :class:`torch.Tensor`: The split tensor + """ + partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D) + start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D) + end_index = start_index + partition_size + if new_buffer: + data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) + data.copy_(tensor.view(-1)[start_index:end_index]) + else: + data = tensor.view(-1)[start_index:end_index] + return data + + +def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor: + """Opposite of above function, gather values from model parallel ranks. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be gathered after communication. + Returns: + :class:`torch.Tensor`: The gathered tensor. + """ + world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + numel = torch.numel(tensor) + numel_gathered = world_size * numel + gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) + chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)] + dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D)) + return gathered diff --git a/colossalai/constants.py b/colossalai/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf9085f9fbb63ea18d2712f99c08f24b539245d --- /dev/null +++ b/colossalai/constants.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence'] +TENSOR_PARALLEL_MODE = 'tensor_parallel_mode' + +# initializer +INITIALIZER_MAPPING = { + 'data': 'Initializer_Data', + 'tensor': 'Initializer_Tensor', + 'pipeline': 'Initializer_Pipeline', + 'embedding': 'Initializer_Embedding', + '1d': 'Initializer_1D', + '2d': 'Initializer_2D', + '2.5d': 'Initializer_2p5D', + '3d': 'Initializer_3D', + 'sequence': 'Initializer_Sequence', + 'model': 'Initializer_Model', + 'moe': 'Initializer_Moe' +} + +# 3D parallelism groups +INPUT_GROUP_3D = 'input_group_3d' +WEIGHT_GROUP_3D = 'weight_group_3d' +OUTPUT_GROUP_3D = 'output_group_3d' +INPUT_X_WEIGHT_3D = 'input_x_weight_group_3d' +OUTPUT_X_WEIGHT_3D = 'output_x_weight_group_3d' + +# Attributes of tensor parallel parameters +IS_TENSOR_PARALLEL = 'is_tensor_parallel' +NUM_PARTITIONS = 'num_partitions' +TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS] diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..50178b5fa850777f8455798cc6ab9d7254c5a9fe --- /dev/null +++ b/colossalai/context/__init__.py @@ -0,0 +1,6 @@ +from .config import Config, ConfigException +from .parallel_context import ParallelContext +from .parallel_mode import ParallelMode +from .moe_context import MOE_CONTEXT +from .process_group_initializer import * +from .random import * diff --git a/colossalai/context/config.py b/colossalai/context/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8903707708df96eac7a0a70343e37e984e6fabed --- /dev/null +++ b/colossalai/context/config.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import inspect +import sys +from importlib.machinery import SourceFileLoader +from pathlib import Path +from colossalai.logging import get_dist_logger + + +class Config(dict): + """This is a wrapper class for dict objects so that values of which can be + accessed as attributes. + + Args: + config (dict): The dict object to be wrapped. + """ + + def __init__(self, config: dict = None): + if config is not None: + for k, v in config.items(): + self._add_item(k, v) + + def __missing__(self, key): + raise KeyError(key) + + def __getattr__(self, key): + try: + value = super(Config, self).__getitem__(key) + return value + except KeyError: + raise AttributeError(key) + + def __setattr__(self, key, value): + super(Config, self).__setitem__(key, value) + + def _add_item(self, key, value): + if isinstance(value, dict): + self.__setattr__(key, Config(value)) + else: + self.__setattr__(key, value) + + def update(self, config): + assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.' + for k, v in config.items(): + self._add_item(k, v) + return self + + @staticmethod + def from_file(filename: str): + """Reads a python file and constructs a corresponding :class:`Config` object. + + Args: + filename (str): Name of the file to construct the return object. + + Returns: + :class:`Config`: A :class:`Config` object constructed with information in the file. + + Raises: + AssertionError: Raises an AssertionError if the file does not exist, or the file is not .py file + """ + + # check config path + if isinstance(filename, str): + filepath = Path(filename).absolute() + elif isinstance(filename, Path): + filepath = filename.absolute() + + assert filepath.exists(), f'{filename} is not found, please check your configuration path' + + # check extension + extension = filepath.suffix + assert extension == '.py', 'only .py files are supported' + + # import the config as module + remove_path = False + if filepath.parent not in sys.path: + sys.path.insert(0, (filepath)) + remove_path = True + + module_name = filepath.stem + source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath)) + module = source_file.load_module() + + # load into config + config = Config() + + for k, v in module.__dict__.items(): + if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v): + continue + else: + config._add_item(k, v) + + logger = get_dist_logger() + logger.debug('variables which starts with __, is a module or class declaration are omitted in config file') + + # remove module + del sys.modules[module_name] + if remove_path: + sys.path.pop(0) + + return config + + +class ConfigException(Exception): + pass diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py new file mode 100644 index 0000000000000000000000000000000000000000..b41f4072a4052113b3e3a79a20c0278b9fed8295 --- /dev/null +++ b/colossalai/context/moe_context.py @@ -0,0 +1,129 @@ +from typing import Tuple + +import torch +import torch.distributed as dist + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.tensor import ProcessGroup + + +def _check_sanity(): + from colossalai.core import global_context as gpc + if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1: + raise NotImplementedError("Moe is not compatible with tensor or " + "pipeline parallel at present.") + + +class MoeParallelInfo: + """Moe parallelism information, storing parallel sizes and groups. + """ + + def __init__(self, ep_size: int, dp_size: int): + _check_sanity() + self.ep_size = ep_size + self.dp_size = dp_size + self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size) + self.ep_group = self.pg.tp_process_group() + self.dp_group = self.pg.dp_process_group() + + +class MoeContext(metaclass=SingletonMeta): + """MoE parallel context manager. This class manages different + parallel groups in MoE context and MoE loss in training. + """ + + def __init__(self): + self.world_size = 1 + # Users may want to set maximum expert parallel size smaller than the world size + # since very low bandwidth across nodes may constrain the performance of MoE + # When we have a maximum expert parallel size, we have a minimum data parallel size naturally + self.max_ep_size = 1 + self.min_dp_size = 1 + self.aux_loss = None + self.use_kernel_optim = True + + self.has_setup = False + self._parallel_info_dict = dict() + + @property + def parallel_info_dict(self): + return self._parallel_info_dict + + @property + def is_initialized(self): + return self.has_setup + + def setup(self, seed: int, use_kernel_optim: bool = True): + assert not self.is_initialized, "MoE distributed context shouldn't be set up again" + _check_sanity() + assert torch.cuda.is_available(), "MoE requires to enable CUDA first" + + self.world_size = dist.get_world_size() + + from colossalai.core import global_context as gpc + self.max_ep_size = gpc.config.get('max_ep_size', self.world_size) + assert self.world_size % self.max_ep_size == 0, \ + "Maximum expert parallel size must be a factor of the number of GPUs" + self.min_dp_size = self.world_size // self.max_ep_size + + # Enabling kernel optimization may raise error in some cases + # Users can close kernel optimization manually + self.use_kernel_optim = use_kernel_optim + + from .random import moe_set_seed + moe_set_seed(seed) + self.has_setup = True + + def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]: + """Calculate the Data Parallel Group and Expert Parallel Group. + + Parameters + ---------- + num_experts : int + The number experts + + Returns + ------- + int, MoeParallelInfo + number of local experts, the MoeParallelInfo of the current ep_size + """ + + gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater + lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less + + assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \ + " is not a multiple of ep size or vice versa." + + # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size, + # there are multiple experts in each GPU and each GPU has different experts + # So it's data parallel size is 1 + # Otherwise, there is only one expert in each GPU + # The data parallel size should be calculated + dp_size = 1 if gt_flag else self.max_ep_size // num_experts + ep_size = self.max_ep_size // dp_size + + # Calculate the number of experts for each GPU + num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size + + # Don't forget to multiply minimum data parallel size + dp_size *= self.min_dp_size + if not (ep_size in self.parallel_info_dict): + self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size) + + return num_local_experts, self.parallel_info_dict[ep_size] + + def set_kernel_not_use(self): + self.use_kernel_optim = False + + def reset_loss(self): + self.aux_loss = 0 + + def add_loss(self, loss): + self.aux_loss += loss + + def get_loss(self): + return self.aux_loss + + +MOE_CONTEXT = MoeContext() diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py new file mode 100644 index 0000000000000000000000000000000000000000..003f0cdd91b6630fe1a88271eed3afdd4021c3b8 --- /dev/null +++ b/colossalai/context/parallel_context.py @@ -0,0 +1,578 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import random +import socket +from collections import Counter +from threading import local +from typing import Union + +import numpy as np +import torch +import torch.distributed as dist + +from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING +from colossalai.context.config import Config +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.logging import get_dist_logger +from colossalai.registry import DIST_GROUP_INITIALIZER + +from .parallel_mode import ParallelMode +from .random import add_seed, get_seeds, set_mode + + +class ParallelContext(metaclass=SingletonMeta): + """This class provides interface functions for users to get the parallel context, + such as the global rank, the local rank, the world size, etc. of each device. + + Note: + The parallel_mode used in this class should be concluded in ``ParallelMode``. + More details about ``ParallelMode`` could be found in + `parallel_mode `_. + """ + + def __init__(self): + # distributed settings + self._global_ranks = dict() + self._local_ranks = dict() + self._world_sizes = dict() + self._groups = dict() + self._cpu_groups = dict() + self._ranks_in_group = dict() + + # load config from file + self._config = None + + # default 3D parallel args, will be overwritten during process group initialization + self.world_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 1 + self.tensor_parallel_size = 1 + self.num_processes_on_current_node = -1 + self.virtual_pipeline_parallel_size = None + self.virtual_pipeline_parallel_rank = None + + # logging + self._verbose = False + self._logger = get_dist_logger() + + @property + def config(self): + return self._config + + @property + def verbose(self): + return self._verbose + + @verbose.setter + def verbose(self, verbose_: bool): + self._verbose = verbose_ + + def load_config(self, config: Union[dict, str]): + """Loads the configuration from either a dict or a file. + + Args: + config (dict or str): Either a dict containing the configuration information or the filename + of a file containing the configuration information. + + Raises: + TypeError: Raises a TypeError if `config` is neither a dict nor a str. + """ + if isinstance(config, str): + self._config = Config.from_file(config) + elif isinstance(config, dict): + self._config = Config(config) + else: + raise TypeError("Invalid type for config, only dictionary or string is supported") + + def detect_num_processes_on_current_node(self): + hostname = socket.gethostname() + hostname_list = [None for _ in range(self.get_world_size(ParallelMode.GLOBAL))] + dist.all_gather_object(hostname_list, hostname, group=self.get_group(ParallelMode.GLOBAL)) + counter = Counter(hostname_list) + self.num_processes_on_current_node = counter[hostname] + + @staticmethod + def _check_parallel_mode(parallel_mode: ParallelMode): + assert isinstance(parallel_mode, ParallelMode), \ + f'expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}' + + def get_global_rank(self): + """Returns the global rank of the current device. + + Returns: + int: The global rank of the current device + """ + return self._global_ranks[ParallelMode.GLOBAL] + + def add_global_rank(self, parallel_mode: ParallelMode, rank: int): + """Adds the global rank of the current device for `parallel_mode` to the context. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank. + rank (int): The rank to be added + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + """ + self._check_parallel_mode(parallel_mode) + self._global_ranks[parallel_mode] = rank + + def get_local_rank(self, parallel_mode: ParallelMode): + """Returns the local rank of the current device. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + int: The local rank of the current device for `parallel_mode`. + """ + self._check_parallel_mode(parallel_mode) + return self._local_ranks[parallel_mode] + + def _add_local_rank(self, parallel_mode: ParallelMode, rank: int): + """Adds the local rank of the current device for `parallel_mode` to the context. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank. + rank (int): The rank to be added. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + """ + self._check_parallel_mode(parallel_mode) + self._local_ranks[parallel_mode] = rank + + def get_next_global_rank(self, parallel_mode: ParallelMode): + """Returns the global rank of the next device. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + int: The global rank of the next device for `parallel_mode`. + """ + self._check_parallel_mode(parallel_mode) + + # get rank and world size + local_rank = self.get_local_rank(parallel_mode) + world_size = self.get_world_size(parallel_mode) + ranks_in_group = self.get_ranks_in_group(parallel_mode) + + return ranks_in_group[(local_rank + 1) % world_size] + + def get_prev_global_rank(self, parallel_mode: ParallelMode): + """Returns the global rank of the previous device. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + int: The global rank of the previous device for `parallel_mode`. + """ + self._check_parallel_mode(parallel_mode) + + # get rank and world size + local_rank = self.get_local_rank(parallel_mode) + world_size = self.get_world_size(parallel_mode) + ranks_in_group = self.get_ranks_in_group(parallel_mode) + + return ranks_in_group[(local_rank - 1) % world_size] + + def is_first_rank(self, parallel_mode: ParallelMode): + """Returns a boolean value indicating whether the current device is the first one + among its group for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + bool: a boolean value indicating whether the current device is the first one + among its group for `parallel_mode`. + """ + rank = self.get_local_rank(parallel_mode) + return rank == 0 + + def is_last_rank(self, parallel_mode: ParallelMode): + """Returns a boolean value indicating whether the current device is the last one + among its group for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + bool: a boolean value indicating whether the current device is the first one + among its group for `parallel_mode`. + """ + rank = self.get_local_rank(parallel_mode) + world_size = self.get_world_size(parallel_mode) + return rank == world_size - 1 + + def is_pipeline_first_stage(self, ignore_virtual=False): + if not ignore_virtual: + if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != 0: + return False + return self.is_first_rank(ParallelMode.PIPELINE) + + def is_pipeline_last_stage(self, ignore_virtual=False): + if not ignore_virtual: + if self.virtual_pipeline_parallel_size \ + is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1: + return False + return self.is_last_rank(ParallelMode.PIPELINE) + + def get_world_size(self, parallel_mode: ParallelMode): + """Returns the world size for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + int: The world size for `parallel_mode`. + """ + self._check_parallel_mode(parallel_mode) + return self._world_sizes[parallel_mode] + + def _add_world_size(self, parallel_mode: ParallelMode, world_size: int): + """Adds world size for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode corresponding to the process group + world_size (int): The world size to be added + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + """ + self._check_parallel_mode(parallel_mode) + self._world_sizes[parallel_mode] = world_size + + def get_group(self, parallel_mode: ParallelMode): + """Returns the group of the current device for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`. + """ + self._check_parallel_mode(parallel_mode) + return self._groups[parallel_mode] + + def _add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup): + """Adds the group of the current device for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + group (torch.distributed.ProcessGroup): The group to be added + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + """ + self._check_parallel_mode(parallel_mode) + self._groups[parallel_mode] = group + + def get_cpu_group(self, parallel_mode: ParallelMode): + """Returns the Gloo group of the current device for `parallel_mode`. + + :param parallel_mode: The chosen parallel mode + :type parallel_mode: :class:`colossalai.context.ParallelMode` + :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode` + :return: The group of the current device for `parallel_mode` + :rtype: torch.distributed.ProcessGroup + """ + self._check_parallel_mode(parallel_mode) + return self._cpu_groups[parallel_mode] + + def _add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup): + """Adds the Gloo group of the current device for `parallel_mode`. + + :param parallel_mode: The chosen parallel mode + :type parallel_mode: :class:`colossalai.context.ParallelMode` + :param group: The group to be added + :type group: torch.distributed.ProcessGroup + :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode` + """ + self._check_parallel_mode(parallel_mode) + self._cpu_groups[parallel_mode] = group + + def get_ranks_in_group(self, parallel_mode: ParallelMode): + """Returns the rank of the current device for `parallel_mode` in the group. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + + Returns: + int: The rank of the current device for `parallel_mode` in the group. + """ + self._check_parallel_mode(parallel_mode) + return self._ranks_in_group[parallel_mode] + + def _add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list): + """Adds the ranks of the current device for `parallel_mode` in the group. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + ranks (list): List of ranks to be added + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode`. + """ + self._check_parallel_mode(parallel_mode) + self._ranks_in_group[parallel_mode] = ranks + + def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int): + """Initializes the global distributed environment + + Args: + rank (int): rank for the default process group. + world_size (int): world size of the default process group. + backend (str): backend for ``torch.distributed`` + host (str): the master address for distributed training. + port (str): the master port for distributed training + """ + # initialize the default process group + init_method = f'tcp://[{host}]:{port}' + dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) + + # None will give the default global process group for pytorch dist operations + ranks = list(range(world_size)) + cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None + self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL) + self.add_global_rank(ParallelMode.GLOBAL, rank) + + def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode): + self._add_local_rank(mode, local_rank) + self._add_world_size(mode, world_size) + self._add_group(mode, process_group) + self._add_cpu_group(mode, cpu_group) + self._add_ranks_in_group(mode, ranks_in_group) + + def check_sanity(self): + """Checks sanity of the parallel context. + + Raises: + AssertionError: Raises an AssertionError if the world size does not equal to the product + of data parallel size, pipeline parallel size and tensor parallel size. + """ + dps = self.data_parallel_size + pps = self.pipeline_parallel_size + tps = self.tensor_parallel_size + ws = self.world_size + assert ws == dps * pps * \ + tps, f"Expected the world size {ws} to be equal to data" \ + f" parallel size ({dps}) * pipeline parallel size " \ + f"({pps}) * tensor parallel size ({tps})" + + def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str): + if key in config: + ele = config[key] + if isinstance(ele, int): + setattr(self, attr_name, ele) + elif isinstance(ele, dict): + setattr(self, attr_name, ele['size']) + else: + raise NotImplementedError( + f'{"Parallel configuration does not support this kind of argument, please use int or dict"}') + + def init_parallel_groups(self): + """Initializes the parallel groups. + + Raises: + AssertionError: Raises an AssertionError if the field parallel is not present in the config file. + """ + + # get rank and world size + rank = self.get_global_rank() + world_size = self.get_world_size(ParallelMode.GLOBAL) + self.world_size = world_size + + # set parallel size as attributes for global context + parallel_config = self.config.get('parallel', None) + if parallel_config is not None: + self._set_parallel_size_from_config(parallel_config, 'pipeline', 'pipeline_parallel_size') + self._set_parallel_size_from_config(parallel_config, 'tensor', 'tensor_parallel_size') + + # the user should not set the data parallel size manually + # instead, it should be calculated based on other parallel config + self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size) + + # get the tensor parallel mode and check + tensor_parallel_mode = None + if parallel_config is not None and 'tensor' in \ + parallel_config and 'mode' in parallel_config['tensor']: + tensor_parallel_mode = parallel_config['tensor']['mode'] + assert tensor_parallel_mode in ALLOWED_MODES, \ + f"mode in the parallel config must be set to one of {ALLOWED_MODES}" + env.mode = tensor_parallel_mode + + self.check_sanity() + + pg_init = [] + # LSG: init data parallel process group for compatibility with other parallel module such as zero + pg_init.append(dict(type=INITIALIZER_MAPPING['data'])) + + # LSG: init model parallel process group for compatibility with amp and clip grad + pg_init.append(dict(type=INITIALIZER_MAPPING['model'])) + + if self.pipeline_parallel_size > 1: + pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline'])) + pg_init.append(dict(type=INITIALIZER_MAPPING['tensor'])) + + # init specific tensor parallel group + if tensor_parallel_mode is not None: + tensor_parallel_cfg = parallel_config['tensor'].copy() + + # remove duplicate parameters + tensor_parallel_cfg.pop('mode') + tensor_parallel_cfg.pop('size') + + # add this config to initialize later + pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg)) + + # run initialization of different process groups + for initializer_cfg in pg_init: + cfg = initializer_cfg.copy() + initializer_type = cfg.pop('type') + initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config, + self.data_parallel_size, + self.pipeline_parallel_size, + self.tensor_parallel_size, **cfg) + parallel_setting = initializer.init_dist_group() + if isinstance(parallel_setting, list): + for args in parallel_setting: + self._register_dist(*args) + else: + self._register_dist(*parallel_setting) + + def is_initialized(self, parallel_mode: ParallelMode): + """Returns a boolean value indicating whether `parallel_mode` is initialized + in the current system. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Returns: + bool: a boolean value indicating whether `parallel_mode` is initialized in the current system. + """ + return parallel_mode in self._groups + + def destroy(self): + """Destroys the current distributed parallel environment. + """ + for mode, group in self._groups.items(): + if mode is not ParallelMode.GLOBAL: + dist.destroy_process_group(group) + # destroy global process group + dist.destroy_process_group() + self._groups.clear() + + def set_device(self, device_ordinal: int = None): + """Sets distributed processes to be bound to devices. + + Args: + device_ordinal (int, optional): the device id to be bound to + """ + global_rank = self.get_global_rank() + if device_ordinal is None: + devices_per_node = torch.cuda.device_count() + device_ordinal = global_rank % devices_per_node + + torch.cuda.set_device(device_ordinal) + if self._verbose: + self._logger.info(f'process rank {global_rank} is bound to device {device_ordinal}') + + def set_seed(self, seed: int): + """Sets seeds for all random libraries. + + Args: + seed (int): seed for random states + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + global_rank = self.get_global_rank() + + if torch.cuda.is_available(): + # create random seed for different parallel modes + # data parallel seed are kept the same + parallel_seed = seed + add_seed(ParallelMode.DATA, parallel_seed) + + # model parallel seeds are different across ranks + pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0) + + # add seed for data parallel and tensor parallel only + if self.is_initialized(ParallelMode.TENSOR): + tp_rank = self.get_local_rank(ParallelMode.TENSOR) + # 100 is only to increase the diff in seeds between pipeline stages + tp_rank_with_offset = tp_rank + pipeline_offset * 1024 + tp_seed = seed + tp_rank_with_offset + add_seed(ParallelMode.TENSOR, tp_seed) + + set_mode(ParallelMode.DATA) + seeds = get_seeds() + seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()]) + + if self._verbose: + self._logger.info(f"initialized seed on rank {global_rank}, " + f"numpy: {seed}, python random: {seed}, {seed_str}," + f"the default parallel seed is {ParallelMode.DATA}.") + else: + if self._verbose: + self._logger.info( + f"initialized seed on rank {global_rank}, " + f"numpy: {seed}, python random: {seed}, pytorch: {seed}", + ranks=[0]) + self._logger.info( + 'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states', + ranks=[0]) + + def set_virtual_pipeline_parallel_size(self, size): + self.virtual_pipeline_parallel_size = size + + def set_virtual_pipeline_parallel_rank(self, rank): + self.virtual_pipeline_parallel_rank = rank + + +global_context = ParallelContext() diff --git a/colossalai/context/parallel_mode.py b/colossalai/context/parallel_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf6fa53dc1e5c31fbaf1c9140e0915419af704c --- /dev/null +++ b/colossalai/context/parallel_mode.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from enum import Enum + + +# parallel modes +class ParallelMode(Enum): + """This is an enumeration class containing all possible parallel modes. + """ + + GLOBAL = 'global' + + # common parallel + DATA = 'data' + + # model parallel - containing tensor and pipeline parallel groups + # this is added to facilitate amp and grad clipping in hybrid parallel + MODEL = 'model' + + # pipeline parallel + PIPELINE = 'pipe' + + # containing all ranks in tensor parallel + TENSOR = 'tensor' + + # sequence parallel + SEQUENCE = 'sequence' + SEQUENCE_DP = 'sequence_dp' + + # 1D Parallel + PARALLEL_1D = '1d' + + # 2D parallel + PARALLEL_2D_ROW = '2d_row' + PARALLEL_2D_COL = '2d_col' + + # 3D parallel + PARALLEL_3D_INPUT = '3d_input' + PARALLEL_3D_WEIGHT = '3d_weight' + PARALLEL_3D_OUTPUT = '3d_output' + PARALLEL_3D_INPUT_X_WEIGHT = "3d_input_x_weight" + PARALLEL_3D_OUTPUT_X_WEIGHT = "3d_output_x_weight" + + # 2.5D parallel + PARALLEL_2P5D_ROW = '2p5d_row' + PARALLEL_2P5D_COL = '2p5d_col' + PARALLEL_2P5D_DEP = '2p5d_dep' + PARALLEL_2P5D_XZ = '2p5d_xz' diff --git a/colossalai/context/process_group_initializer/__init__.py b/colossalai/context/process_group_initializer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d3937a9474376f0ecb7af612121bb4c3e5f5a497 --- /dev/null +++ b/colossalai/context/process_group_initializer/__init__.py @@ -0,0 +1,15 @@ +from .initializer_1d import Initializer_1D +from .initializer_2d import Initializer_2D +from .initializer_2p5d import Initializer_2p5D +from .initializer_3d import Initializer_3D +from .initializer_data import Initializer_Data +from .initializer_pipeline import Initializer_Pipeline +from .initializer_sequence import Initializer_Sequence +from .initializer_tensor import Initializer_Tensor +from .initializer_model import Initializer_Model +from .process_group_initializer import ProcessGroupInitializer + +__all__ = [ + 'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline', 'Initializer_Data', 'Initializer_2p5D', + 'Initializer_2D', 'Initializer_3D', 'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model' +] diff --git a/colossalai/context/process_group_initializer/initializer_1d.py b/colossalai/context/process_group_initializer/initializer_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..4c05028041cef2a9ad453b05d17d35b09ec2617d --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_1d.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.distributed as dist +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.registry import DIST_GROUP_INITIALIZER + +from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_1D(ProcessGroupInitializer): + """A ProcessGroupInitializer for 1d tensor parallelism. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_group = self.world_size // self.tensor_parallel_size + + def init_dist_group(self): + """Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 1D tensor parallelism's information in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_1D + env.parallel_input_1d = False + + for i in range(self.num_group): + ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode diff --git a/colossalai/context/process_group_initializer/initializer_2d.py b/colossalai/context/process_group_initializer/initializer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..7fbe3be5901f73b8c670c71582771ab861e9fccd --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_2d.py @@ -0,0 +1,155 @@ +import math + +import torch.distributed as dist + +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.registry import DIST_GROUP_INITIALIZER + +from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer + + +def _check_summa_env_var(summa_dim): + # check environment variable for SUMMA + env_summa_dim = env.summa_dim + + if env_summa_dim: + assert int(env_summa_dim) == summa_dim, \ + 'SUMMA_DIM has been set in the current environment and ' \ + 'does not match with the value passed to this initialized' + else: + env.summa_dim = summa_dim + + +class Initializer_2D_Row(ProcessGroupInitializer): + """2d tensor parallel initialization among rows. + + Args: + num_group (int): The number of all tensor groups. + summa_dim (int): The dimension of SUMMA. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group, summa_dim, *args, **kwargs): + super(Initializer_2D_Row, self).__init__(*args, **kwargs) + self.num_group = num_group + self.summa_dim = summa_dim + + def init_dist_group(self): + """Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu. + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 2D tensor row parallelism's information in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_2D_ROW + + for i in range(self.num_group): + for j in range(self.summa_dim): + ranks = [i * self.tensor_parallel_size + j * self.summa_dim + k for k in range(self.summa_dim)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_2D_Col(ProcessGroupInitializer): + """2d tensor parallel initialization among cols. + + Args: + num_group (int): The number of all tensor groups. + summa_dim (int): The dimension of SUMMA. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group, summa_dim, *args, **kwargs): + super(Initializer_2D_Col, self).__init__(*args, **kwargs) + self.num_group = num_group + self.summa_dim = summa_dim + + def init_dist_group(self): + """Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 2D tensor col parallelism's information in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_2D_COL + + for i in range(self.num_group): + for j in range(self.summa_dim): + ranks = [i * self.tensor_parallel_size + j + k * self.summa_dim for k in range(self.summa_dim)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_2D(ProcessGroupInitializer): + """ + Serve as the single entry point to 2D parallel initialization. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_group = self.world_size // self.tensor_parallel_size + self.summa_dim = int(math.sqrt(self.tensor_parallel_size)) + + assert self.tensor_parallel_size == self.summa_dim ** 2, \ + "2D summa dim should equal to tensor parallel size ^ 0.5" + _check_summa_env_var(self.summa_dim) + + self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs) + self.row_initializer = Initializer_2D_Row(self.num_group, self.summa_dim, *args, **kwargs) + + def init_dist_group(self): + """Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]: + 2D tensor parallelism's information in a list of tuples. + """ + parallel_setting = [self.row_initializer.init_dist_group(), self.col_initializer.init_dist_group()] + return parallel_setting diff --git a/colossalai/context/process_group_initializer/initializer_2p5d.py b/colossalai/context/process_group_initializer/initializer_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6fdc5d715c30169f04cef54abd946c4c46b904 --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_2p5d.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math + +import torch.distributed as dist +from colossalai.context import Config +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.registry import DIST_GROUP_INITIALIZER + +from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer + + +def _check_tesseract_env_var(tesseract_dim: int, tesseract_dep: int): + # check global variable for TESSERACT + env_tesseract_dim = env.tesseract_dim + env_tesseract_dep = env.tesseract_dep + + if env_tesseract_dim and env_tesseract_dep: + assert int(env_tesseract_dim) == tesseract_dim, \ + 'TESSERACT_DIM has been set in the current environment and ' \ + 'does not match with the value passed to this initialized' + assert int(env_tesseract_dep) == tesseract_dep, \ + 'TESSERACT_DEP has been set in the current environment and ' \ + 'does not match with the value passed to this initialized' + else: + env.tesseract_dim = tesseract_dim + env.tesseract_dep = tesseract_dep + + +# i row j col k dep +class Initializer_2p5D_ROW(ProcessGroupInitializer): + """2.5d tensor parallel initialization among rows. + + Args: + tesseract_dim (int): The dimension of tesseract. + tesseract_dep (int): The dimension of depth. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): + super(Initializer_2p5D_ROW, self).__init__(*args) + self.num_group = self.world_size // self.tensor_parallel_size + self.tesseract_dep = tesseract_dep + self.tesseract_dim = tesseract_dim + assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ + "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" + + def init_dist_group(self): + """Initialize 2.5D tensor row parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 2.5D tensor row parallelism's information in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_2P5D_ROW + + for h in range(self.num_group): + for j in range(self.tesseract_dim): + for k in range(self.tesseract_dep): + ranks = [ + h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k) + for i in range(self.tesseract_dim) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_2p5D_Col(ProcessGroupInitializer): + """2.5d tensor parallel initialization among cols. + + Args: + tesseract_dim (int): The dimension of tesseract. + tesseract_dep (int): The dimension of depth. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): + super(Initializer_2p5D_Col, self).__init__(*args) + self.num_group = self.world_size // self.tensor_parallel_size + self.tesseract_dep = tesseract_dep + self.tesseract_dim = tesseract_dim + + def init_dist_group(self): + """Initialize 2.5D tensor col parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 2.5D tensor col parallelism's information in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_2P5D_COL + + for h in range(self.num_group): + for i in range(self.tesseract_dim): + for k in range(self.tesseract_dep): + ranks = [ + h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k) + for j in range(self.tesseract_dim) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_2p5D_Dep(ProcessGroupInitializer): + """2.5D tensor parallel initialization among depths. + + Args: + tesseract_dim (int): The dimension of tesseract. + tesseract_dep (int): The dimension of depth. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): + super(Initializer_2p5D_Dep, self).__init__(*args) + self.num_group = self.world_size // self.tensor_parallel_size + self.tesseract_dep = tesseract_dep + self.tesseract_dim = tesseract_dim + + def init_dist_group(self): + """Initialize 2.5D tensor depth parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 2.5D tensor depth parallelism's information in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_2P5D_DEP + + for h in range(self.num_group): + for i in range(self.tesseract_dim): + for j in range(self.tesseract_dim): + ranks = [ + h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k) + for k in range(self.tesseract_dep) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +# i row j col k dep +class Initializer_2p5D_XZ(ProcessGroupInitializer): + """2.5d tensor parallel initialization among cols times dep. + + Args: + tesseract_dim (int): The dimension of tesseract. + tesseract_dep (int): The dimension of depth. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): + super(Initializer_2p5D_XZ, self).__init__(*args) + self.num_group = self.world_size // self.tensor_parallel_size + self.tesseract_dep = tesseract_dep + self.tesseract_dim = tesseract_dim + + def init_dist_group(self): + """Initialize 2.5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 2.5D tensor colXdepth parallelism's information in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_2P5D_XZ + + for h in range(self.num_group): + for i in range(self.tesseract_dim): + ranks = [ + h * self.tensor_parallel_size + i + self.tesseract_dim * (j + self.tesseract_dim * k) + for k in range(self.tesseract_dep) + for j in range(self.tesseract_dim) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_2p5D(ProcessGroupInitializer): + """ + Serve as the single entry point to Tesseract parallel initialization. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + depth (int): The depth of 2.5d parallel. + """ + + def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int, + tensor_parallel_size: int, depth: int): + args = (rank, world_size, config, data_parallel_size, pipeline_parallel_size, tensor_parallel_size) + super().__init__(*args) + self.num_group = self.world_size // self.tensor_parallel_size + self.tesseract_dim = int(math.sqrt(self.tensor_parallel_size / depth)) + self.tesseract_dep = depth + + assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ + "2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5" + _check_tesseract_env_var(self.tesseract_dim, self.tesseract_dep) + + self.col_initializer = Initializer_2p5D_Col(self.tesseract_dim, self.tesseract_dep, *args) + self.row_initializer = Initializer_2p5D_ROW(self.tesseract_dim, self.tesseract_dep, *args) + self.dep_initializer = Initializer_2p5D_Dep(self.tesseract_dim, self.tesseract_dep, *args) + self.xz_initializer = Initializer_2p5D_XZ(self.tesseract_dim, self.tesseract_dep, *args) + + def init_dist_group(self): + """Initialize 2.5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]: + Whole 2.5D tensor parallelism's information in a list of tuples. + """ + parallel_setting = [ + self.col_initializer.init_dist_group(), + self.row_initializer.init_dist_group(), + self.dep_initializer.init_dist_group(), + self.xz_initializer.init_dist_group() + ] + return parallel_setting diff --git a/colossalai/context/process_group_initializer/initializer_3d.py b/colossalai/context/process_group_initializer/initializer_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..1ed8eec86efc83315ee8b549a9a035bc36dca6da --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_3d.py @@ -0,0 +1,329 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math + +import torch.distributed as dist + +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.registry import DIST_GROUP_INITIALIZER + +from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer + + +def _check_depth_env_var(depth): + # check global variable + env_depth = env.depth_3d + + if env_depth: + assert int(env_depth) == depth, \ + 'DEPTH_3D has been set in the current environment and ' \ + 'does not match with the value passed to this initialized' + else: + env.depth_3d = depth + + +class Initializer_3D_Input(ProcessGroupInitializer): + """3D tensor parallel initialization among input. + + Args: + num_group (int): The number of all tensor groups. + depth (int): Depth of 3D parallelism. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group: int, depth: int, *args): + super().__init__(*args) + self.num_group = num_group + self.depth = depth + + def init_dist_group(self): + """Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 3D tensor parallelism's information among input in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_3D_INPUT + env.input_group_3d = mode + + for h in range(self.num_group): + for i in range(self.depth): + for k in range(self.depth): + ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_3D_Weight(ProcessGroupInitializer): + """3D tensor parallel initialization among weight. + + Args: + num_group (int): The number of all tensor groups. + depth (int): Depth of 3D parallelism. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group: int, depth: int, *args): + super().__init__(*args) + self.num_group = num_group + self.depth = depth + + def init_dist_group(self): + """Initialize 3D tensor parallel groups among weight, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 3D tensor parallelism's information among weight in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_3D_WEIGHT + env.weight_group_3d = mode + + for h in range(self.num_group): + for k in range(self.depth): + for j in range(self.depth): + ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_3D_Output(ProcessGroupInitializer): + """3D tensor parallel initialization among output. + + Args: + num_group (int): The number of all tensor groups. + depth (int): Depth of 3D parallelism. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group: int, depth: int, *args): + super().__init__(*args) + self.num_group = num_group + self.depth = depth + + def init_dist_group(self): + """Initialize 3D tensor parallel groups among output, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 3D tensor parallelism's information among output in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_3D_OUTPUT + env.output_group_3d = mode + + for h in range(self.num_group): + for i in range(self.depth): + for j in range(self.depth): + ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_3D_InputxWeight(ProcessGroupInitializer): + """3D tensor parallel initialization among input. + + Args: + num_group (int): The number of all tensor groups. + depth (int): Depth of 3D parallelism. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group: int, depth: int, *args): + super().__init__(*args) + self.num_group = num_group + self.depth = depth + + def init_dist_group(self): + """Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 3D tensor parallelism's information among input in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_3D_INPUT_X_WEIGHT + env.input_x_weight_group_3d = mode + + for h in range(self.num_group): + for k in range(self.depth): + ranks = [ + h * self.depth**3 + i + self.depth * (j + self.depth * k) + for j in range(self.depth) + for i in range(self.depth) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_3D_OutputxWeight(ProcessGroupInitializer): + """3D tensor parallel initialization among input. + + Args: + num_group (int): The number of all tensor groups. + depth (int): Depth of 3D parallelism. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group: int, depth: int, *args): + super().__init__(*args) + self.num_group = num_group + self.depth = depth + + def init_dist_group(self): + """Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 3D tensor parallelism's information among input in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_3D_OUTPUT_X_WEIGHT + env.output_x_weight_group_3d = mode + + for h in range(self.num_group): + for j in range(self.depth): + ranks = [ + h * self.depth**3 + i + self.depth * (j + self.depth * k) + for k in range(self.depth) + for i in range(self.depth) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_3D(ProcessGroupInitializer): + """Serve as the single entry point to 3D parallel initialization. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, *args): + super().__init__(*args) + self.num_group = self.world_size // self.tensor_parallel_size + self.depth = round(math.pow(self.tensor_parallel_size, 1 / 3)) + assert self.tensor_parallel_size == self.depth ** 3, \ + f'3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})' + _check_depth_env_var(self.depth) + + self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args) + self.weight_initializer = Initializer_3D_Weight(self.num_group, self.depth, *args) + self.output_initializer = Initializer_3D_Output(self.num_group, self.depth, *args) + self.input_x_weight_initializer = Initializer_3D_InputxWeight(self.num_group, self.depth, *args) + self.output_x_weight_initializer = Initializer_3D_OutputxWeight(self.num_group, self.depth, *args) + + def init_dist_group(self): + """Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]: + Whole 3D tensor parallelism's information in a list of tuples. + """ + parallel_setting = [ + self.input_initializer.init_dist_group(), + self.weight_initializer.init_dist_group(), + self.output_initializer.init_dist_group(), + self.input_x_weight_initializer.init_dist_group(), + self.output_x_weight_initializer.init_dist_group() + ] + return parallel_setting diff --git a/colossalai/context/process_group_initializer/initializer_data.py b/colossalai/context/process_group_initializer/initializer_data.py new file mode 100644 index 0000000000000000000000000000000000000000..9715ebff7f00f0fc8a3f13a5dfca436c9b0e144b --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_data.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from torch import distributed as dist + +from colossalai.registry import DIST_GROUP_INITIALIZER + +from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Data(ProcessGroupInitializer): + """A ProcessGroupInitializer for data parallelism. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_data_parallel_group = self.world_size // self.data_parallel_size + + def init_dist_group(self): + """Initialize data parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + A Data parallelism's information tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.DATA + + for i in range(self.num_data_parallel_group): + ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode diff --git a/colossalai/context/process_group_initializer/initializer_model.py b/colossalai/context/process_group_initializer/initializer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..99b9cc0d4edce35915c52c01fa5875545256ba97 --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_model.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.distributed as dist +from colossalai.registry import DIST_GROUP_INITIALIZER +from .process_group_initializer import ProcessGroupInitializer +from ..parallel_mode import ParallelMode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Model(ProcessGroupInitializer): + """A ProcessGroupInitializer for model parallelism (model parallel group contains pipeline and tensor parallel + groups). + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_parallel_size = self.tensor_parallel_size * self.pipeline_parallel_size + self.num_group = self.world_size // self.model_parallel_size + + def init_dist_group(self): + """Initialize model parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + A Model parallelism's information tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.MODEL + + for i in range(self.num_group): + ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode diff --git a/colossalai/context/process_group_initializer/initializer_pipeline.py b/colossalai/context/process_group_initializer/initializer_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..0ddb52f63e22f29aff9920d5cdd2aba1748e1eb6 --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_pipeline.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from torch import distributed as dist + +from colossalai.registry import DIST_GROUP_INITIALIZER + +from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Pipeline(ProcessGroupInitializer): + """A ProcessGroupInitializer for pipeline parallelism. + + Args: + rank (int): The rank of current process + world_size (int): Size of whole communication world + config (Config): Running configuration + data_parallel_size (int): Size of data parallel + pipeline_parallel_size (int): Size of pipeline parallel + tensor_parallel_size (int): Size of tensor parallel + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.data_group_size = self.world_size // self.data_parallel_size + self.pipeline_stage_size = self.data_group_size // self.pipeline_parallel_size + + def init_dist_group(self): + """Initialize pipeline parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]: + A Pipeline parallelism's information in list of tuples. + """ + dist_settings = list() + for i in range(self.data_parallel_size): + for j in range(self.pipeline_stage_size): + pipe_ranks = list( + range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size)) + pipe_group_size = len(pipe_ranks) + pipe_group = dist.new_group(pipe_ranks) + group_cpu = dist.new_group(pipe_ranks, backend='gloo') if dist.get_backend() != 'gloo' else pipe_group + + if self.rank in pipe_ranks: + local_rank = pipe_ranks.index(self.rank) + group_world_size = pipe_group_size + process_group = pipe_group + cpu_group = group_cpu + ranks_in_group = pipe_ranks + dist_settings.append( + tuple((local_rank, group_world_size, process_group, cpu_group, ranks_in_group, + ParallelMode.PIPELINE))) + + return dist_settings diff --git a/colossalai/context/process_group_initializer/initializer_sequence.py b/colossalai/context/process_group_initializer/initializer_sequence.py new file mode 100644 index 0000000000000000000000000000000000000000..eaacb14d22825db7913e1c87cfe08063ab5865ee --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_sequence.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import torch.distributed as dist + +from colossalai.registry import DIST_GROUP_INITIALIZER + +from ..parallel_mode import ParallelMode +from .initializer_tensor import Initializer_Tensor +from .process_group_initializer import ProcessGroupInitializer + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Sequence_DP(ProcessGroupInitializer): + """A ProcessGroupInitializer for sequence parallelism all-reduce. + + In Sequence Parallelism, each GPU holds the full copy of model weights, + thus, gradient all-reduce occurs across all processes in the same pipeline stage + + Args: + rank (int): The rank of current process + world_size (int): Size of whole communication world + config (Config): Running configuration + data_parallel_size (int): Size of data parallel + pipeline_parallel_size (int): Size of pipeline parallel + tensor_parallel_size (int): Size of tensor parallel + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dp_size = self.world_size // self.pipeline_parallel_size + self.num_group = self.pipeline_parallel_size + + def init_dist_group(self): + """Initialize Sequence Parallel process groups used for gradient all-reduce. + + Returns: + Tuple: A tuple (local_rank, group_world_size, process_group, ranks_in_group, mode). + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.SEQUENCE_DP + + for i in range(self.num_group): + ranks = [i * self.dp_size + j for j in range(self.dp_size)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Sequence(ProcessGroupInitializer): + """A ProcessGroupInitializer for sequence parallelism. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # reuse tensor parallel initializer code + self._sequence_initializer = Initializer_Tensor(*args, **kwargs) + self._sequence_dp_initializer = Initializer_Sequence_DP(*args, **kwargs) + + def init_dist_group(self): + """Initialize Sequence parallel process groups and assign local_ranks and groups to each gpu. + + Sequence parallelism requires 2 process groups. The first is for model forward where several processes + exchange partial query, key and value embedding to compute self attention values. The second is for + all-reduce to synchronize the model parameters. + + Returns: + List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]: + A Sequence parallelism's information in list of tuples. + """ + + parallel_setting = [] + + local_rank, group_world_size, process_group, cpu_grop, ranks_in_group, mode = \ + self._sequence_initializer.init_dist_group() + # change mode to sequence + mode = ParallelMode.SEQUENCE + + parallel_setting.append((local_rank, group_world_size, process_group, cpu_grop, ranks_in_group, mode)) + parallel_setting.append(self._sequence_dp_initializer.init_dist_group()) + return parallel_setting diff --git a/colossalai/context/process_group_initializer/initializer_tensor.py b/colossalai/context/process_group_initializer/initializer_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..d2b5be9cfffbe9eb7234411c6526d4055c078f12 --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_tensor.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.distributed as dist + +from colossalai.registry import DIST_GROUP_INITIALIZER +from .process_group_initializer import ProcessGroupInitializer +from ..parallel_mode import ParallelMode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Tensor(ProcessGroupInitializer): + """A ProcessGroupInitializer for tensor parallelism. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_tensor_parallel_group = self.world_size // self.tensor_parallel_size + + def init_dist_group(self): + """Initialize tensor parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + A Tensor parallelism's information tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.TENSOR + + for i in range(self.num_tensor_parallel_group): + ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode diff --git a/colossalai/context/process_group_initializer/process_group_initializer.py b/colossalai/context/process_group_initializer/process_group_initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..98150ce8e428a3b9bf81185719685b38efc2bdfd --- /dev/null +++ b/colossalai/context/process_group_initializer/process_group_initializer.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod + +from colossalai.context import Config + + +class ProcessGroupInitializer(ABC): + """An object, knowing the parallelism configuration, that initializes parallel groups. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int, + tensor_parallel_size: int): + self.rank = rank + self.world_size = world_size + self.data_parallel_size = data_parallel_size + self.config = config + self.pipeline_parallel_size = pipeline_parallel_size + self.tensor_parallel_size = tensor_parallel_size + super().__init__() + + @abstractmethod + def init_dist_group(self): + pass diff --git a/colossalai/context/random/__init__.py b/colossalai/context/random/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d64b993257c1574706ee5028224692b4e666fc19 --- /dev/null +++ b/colossalai/context/random/__init__.py @@ -0,0 +1,18 @@ +from ._helper import ( + add_seed, + get_current_mode, + get_seeds, + get_states, + moe_set_seed, + reset_seeds, + seed, + set_mode, + set_seed_states, + sync_states, + with_seed, +) + +__all__ = [ + 'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states', + 'sync_states', 'moe_set_seed', 'reset_seeds' +] diff --git a/colossalai/context/random/_helper.py b/colossalai/context/random/_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..973c4d9faa325820aa1dedc5e133551430778057 --- /dev/null +++ b/colossalai/context/random/_helper.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import functools +from contextlib import contextmanager + +import torch.cuda +from torch import Tensor + +from .seed_manager import SeedManager +from ..parallel_mode import ParallelMode + +_SEED_MANAGER = SeedManager() + + +def get_seeds(): + """Returns the seeds of the seed manager. + + Returns: + dict: The seeds of the seed manager. + """ + return _SEED_MANAGER.seeds + + +def get_states(copy=False): + """Returns the seed states of the seed manager. + + Returns: + dict: The seed states of the seed manager. + """ + states = _SEED_MANAGER.seed_states + + if copy: + new_states = dict() + + for parallel_mode, state in states.items(): + new_states[parallel_mode] = state.clone() + return new_states + else: + return _SEED_MANAGER.seed_states + + +def get_current_mode(): + """Returns the current mode of the seed manager. + + Returns: + :class:`torch.ByteTensor`: The current mode of the seed manager. + """ + return _SEED_MANAGER.current_mode + + +def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False): + """Adds a seed to the seed manager for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + seed (int): The seed to be added + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of + :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + _SEED_MANAGER.add_seed(parallel_mode, seed, overwrite) + + +def set_mode(parallel_mode: ParallelMode): + """Sets the current mode of the seed manager. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + _SEED_MANAGER.set_mode(parallel_mode) + + +def set_seed_states(parallel_mode: ParallelMode, state: Tensor): + """Sets the state of the seed manager for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + state (:class:`torch.Tensor`): the state to be set. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager. + """ + _SEED_MANAGER.set_state(parallel_mode, state) + + +def sync_states(): + current_mode = get_current_mode() + current_states = torch.cuda.get_rng_state() + set_seed_states(current_mode, current_states) + + +@contextmanager +def seed(parallel_mode: ParallelMode): + """ A context for seed switch + + Examples: + + >>> with seed(ParallelMode.DATA): + >>> output = F.dropout(input) + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + try: + # set to new mode + current_mode = _SEED_MANAGER.current_mode + yield _SEED_MANAGER.set_mode(parallel_mode) + finally: + # recover + _SEED_MANAGER.set_mode(current_mode) + + +def with_seed(func, parallel_mode: ParallelMode): + """ + A function wrapper which executes the function with a specified seed. + + Examples: + + >>> # use with decorator + >>> @with_seed(ParallelMode.DATA) + >>> def forward(input): + >>> return F.dropout(input) + >>> out = forward(input) + >>> # OR use it inline + >>> def forward(input): + >>> return F.dropout(input) + >>> wrapper_forward = with_seed(forward, ParallelMode.DATA) + >>> out = wrapped_forward(input) + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # switch mode + current_mode = _SEED_MANAGER.current_mode + _SEED_MANAGER.set_mode(parallel_mode) + + # exec func + out = func(*args, **kwargs) + + # recover state + _SEED_MANAGER.set_mode(current_mode) + + return out + + return wrapper + + +def moe_set_seed(seed): + if torch.cuda.is_available(): + from colossalai.core import global_context as gpc + global_rank = gpc.get_global_rank() + diff_seed = seed + global_rank + add_seed(ParallelMode.TENSOR, diff_seed, True) + print(f"moe seed condition: {global_rank} with tensor seed {diff_seed}", flush=True) + + +def reset_seeds(): + _SEED_MANAGER.reset() diff --git a/colossalai/context/random/seed_manager.py b/colossalai/context/random/seed_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..956f9001200d8706bbd45e1c9b09a175ff10b82d --- /dev/null +++ b/colossalai/context/random/seed_manager.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +from torch import Tensor + +from colossalai.context.parallel_mode import ParallelMode + + +class SeedManager: + """This class is a manager of all random seeds involved in the system. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + def __init__(self): + self._current_mode = None + self._seeds = dict() + self._seed_states = dict() + + @property + def current_mode(self): + return self._current_mode + + @property + def seeds(self): + return self._seeds + + @property + def seed_states(self): + return self._seed_states + + def set_state(self, parallel_mode: ParallelMode, state: Tensor): + """Sets the state of the seed manager for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + state (:class:`torch.Tensor`): the state to be set. + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager. + """ + assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager' + self._seed_states[parallel_mode] = state + + def set_mode(self, parallel_mode: ParallelMode): + """Sets the current mode of the seed manager. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + """ + if self.current_mode: + # save the current state for current mode + self._seed_states[self._current_mode] = torch.cuda.get_rng_state() + + # set the new state for new mode + self._current_mode = parallel_mode + torch.cuda.set_rng_state(self._seed_states[parallel_mode]) + + def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False): + """Adds a seed to the seed manager for `parallel_mode`. + + Args: + parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + seed (int): The seed to be added. + overwrite (bool, optional): Whether allows to overwrite the seed that has been set already + + Raises: + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.context.ParallelMode` + or the seed for `parallel_mode` has been added. + """ + assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided' + if overwrite is False: + assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added' + elif parallel_mode in self._seed_states: + print(f"Warning: {parallel_mode} seed has been overwritten.", flush=True) + + current_state = torch.cuda.get_rng_state() + torch.cuda.manual_seed(seed) + self._seed_states[parallel_mode] = torch.cuda.get_rng_state() + self._seeds[parallel_mode] = seed + torch.cuda.set_rng_state(current_state) + + def reset(self): + self._current_mode = None + self._seeds = dict() + self._seed_states = dict() diff --git a/colossalai/context/singleton_meta.py b/colossalai/context/singleton_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..8ca335119d52ad2a212b1e0c578202b2fc6bb60f --- /dev/null +++ b/colossalai/context/singleton_meta.py @@ -0,0 +1,21 @@ +class SingletonMeta(type): + """ + The Singleton class can be implemented in different ways in Python. Some + possible methods include: base class, decorator, metaclass. We will use the + metaclass because it is best suited for this purpose. + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + """ + Possible changes to the value of the `__init__` argument do not affect + the returned instance. + """ + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + else: + assert len(args) == 0 and len( + kwargs) == 0, f'{cls.__name__} is a singleton class and a instance has been created.' + return cls._instances[cls] diff --git a/colossalai/core.py b/colossalai/core.py new file mode 100644 index 0000000000000000000000000000000000000000..153247bbed9c65db0b2255247137fa9a64a693fa --- /dev/null +++ b/colossalai/core.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from colossalai.context.parallel_context import global_context + +__all__ = ['global_context'] \ No newline at end of file diff --git a/colossalai/device/__init__.py b/colossalai/device/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..689189998c3f6490145ba2648c522570c6f40b4c --- /dev/null +++ b/colossalai/device/__init__.py @@ -0,0 +1,4 @@ +from .alpha_beta_profiler import AlphaBetaProfiler +from .calc_pipeline_strategy import alpa_dp + +__all__ = ['AlphaBetaProfiler', 'alpa_dp'] diff --git a/colossalai/device/alpha_beta_profiler.py b/colossalai/device/alpha_beta_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..af2b10928c6f2f99a429aaa413d527d77d52faf0 --- /dev/null +++ b/colossalai/device/alpha_beta_profiler.py @@ -0,0 +1,388 @@ +import math +import time +from typing import Dict, List, Tuple + +import torch +import torch.distributed as dist + +from colossalai.logging import get_dist_logger + +GB = int((1 << 30)) +BYTE = 4 +FRAMEWORK_LATENCY = 0 + + +class AlphaBetaProfiler: + ''' + Profile alpha and beta value for a given device list. + + Usage: + # Note: the environment of execution is supposed to be + # multi-process with multi-gpu in mpi style. + >>> physical_devices = [0, 1, 4, 5] + >>> ab_profiler = AlphaBetaProfiler(physical_devices) + >>> ab_dict = profiler.alpha_beta_dict + >>> print(ab_dict) + {(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11), + (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12), + (1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11), + (4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)} + ''' + + def __init__(self, + physical_devices: List[int], + alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None, + ctype: str = 'a', + warmup: int = 5, + repeat: int = 25, + latency_iters: int = 5, + homogeneous_tolerance: float = 0.1): + ''' + Args: + physical_devices: A list of device id, each element inside it is the global rank of that device. + alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs. + ctype: 'a' for all-reduce, 'b' for broadcast. + warmup: Number of warmup iterations. + repeat: Number of iterations to measure. + latency_iters: Number of iterations to measure latency. + ''' + self.physical_devices = physical_devices + self.ctype = ctype + self.world_size = len(physical_devices) + self.warmup = warmup + self.repeat = repeat + self.latency_iters = latency_iters + self.homogeneous_tolerance = homogeneous_tolerance + self.process_group_dict = None + self._init_profiling() + if alpha_beta_dict is None: + self.alpha_beta_dict = self.profile_ab() + else: + self.alpha_beta_dict = alpha_beta_dict + + def _init_profiling(self): + # Create process group list based on its global rank + process_group_list = [] + for f_index in range(self.world_size - 1): + for b_index in range(f_index + 1, self.world_size): + process_group_list.append((self.physical_devices[f_index], self.physical_devices[b_index])) + + # Create process group dict which maps process group to its handler + process_group_dict = {} + for process_group in process_group_list: + pg_handler = dist.new_group(process_group) + process_group_dict[process_group] = pg_handler + + self.process_group_dict = process_group_dict + + def _profile(self, process_group, pg_handler, nbytes): + logger = get_dist_logger() + rank = dist.get_rank() + src_device_num = process_group[0] + world_size = len(process_group) + + device = torch.cuda.current_device() + buf = torch.randn(nbytes // 4).to(device) + + torch.cuda.synchronize() + # warmup + for _ in range(self.warmup): + if self.ctype == "a": + dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler) + elif self.ctype == "b": + dist.broadcast(buf, src=src_device_num, group=pg_handler) + torch.cuda.synchronize() + + dist.barrier(group=pg_handler) + begin = time.perf_counter() + for _ in range(self.repeat): + if self.ctype == "a": + dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler) + elif self.ctype == "b": + dist.broadcast(buf, src=src_device_num, group=pg_handler) + torch.cuda.synchronize() + end = time.perf_counter() + dist.barrier(group=pg_handler) + + if rank == src_device_num: + avg_time_s = (end - begin) / self.repeat - FRAMEWORK_LATENCY + alg_band = nbytes / avg_time_s + if self.ctype == "a": + # convert the bandwidth of all-reduce algorithm to the bandwidth of the hardware. + bus_band = 2 * (world_size - 1) / world_size * alg_band + bus_band = alg_band + elif self.ctype == "b": + bus_band = alg_band + + logger.info( + f"GPU:{rank}, Bytes: {nbytes} B,Time: {round(avg_time_s * 1e6,2)} us, Bus bandwidth: {round(bus_band / GB,2)} GB/s" + ) + return (avg_time_s, alg_band) + else: + # Just a placeholder + return (None, None) + + def profile_latency(self, process_group, pg_handler): + ''' + This function is used to profile the latency of the given process group with a series of bytes. + + Args: + process_group: A tuple of global rank of the process group. + pg_handler: The handler of the process group. + + Returns: + latency: None if the latency is not measured, otherwise the median of the latency_list. + ''' + latency_list = [] + for i in range(self.latency_iters): + nbytes = int(BYTE << i) + (t, _) = self._profile(process_group, pg_handler, nbytes) + latency_list.append(t) + + if latency_list[0] is None: + latency = None + else: + median_index = math.floor(self.latency_iters / 2) + latency = latency_list[median_index] + + return latency + + def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)): + ''' + This function is used to profile the bandwidth of the given process group. + + Args: + process_group: A tuple of global rank of the process group. + pg_handler: The handler of the process group. + ''' + (_, bandwidth) = self._profile(process_group, pg_handler, maxbytes) + return bandwidth + + def profile_ab(self): + ''' + This method is used to profiling the alpha and beta value for a given device list. + + Returns: + alpha_beta_dict: A dict which maps process group to its alpha and beta value. + ''' + alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {} + rank = dist.get_rank() + global_pg_handler = dist.new_group(self.physical_devices) + + def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup): + assert rank in process_group + device = torch.cuda.current_device() + rank_max_nbytes = torch.cuda.mem_get_info(device)[0] + rank_max_nbytes = torch.tensor(rank_max_nbytes, device=device) + dist.all_reduce(rank_max_nbytes, op=dist.ReduceOp.MIN, group=pg_handler) + max_nbytes = min(int(1 * GB), int(GB << int(math.log2(rank_max_nbytes.item() / GB)))) + return max_nbytes + + for process_group, pg_handler in self.process_group_dict.items(): + if rank not in process_group: + max_nbytes = None + alpha = None + bandwidth = None + else: + max_nbytes = get_max_nbytes(process_group, pg_handler) + alpha = self.profile_latency(process_group, pg_handler) + bandwidth = self.profile_bandwidth(process_group, pg_handler, maxbytes=max_nbytes) + + if bandwidth is None: + beta = None + else: + beta = 1 / bandwidth + + broadcast_list = [alpha, beta] + dist.broadcast_object_list(broadcast_list, src=process_group[0]) + alpha_beta_dict[process_group] = tuple(broadcast_list) + + # add symmetry pair to the apha_beta_dict + symmetry_ab_dict = {} + for process_group, alpha_beta_pair in alpha_beta_dict.items(): + symmetry_process_group = (process_group[1], process_group[0]) + symmetry_ab_dict[symmetry_process_group] = alpha_beta_pair + + alpha_beta_dict.update(symmetry_ab_dict) + + return alpha_beta_dict + + def search_best_logical_mesh(self): + ''' + This method is used to search the best logical mesh for the given device list. + + The best logical mesh is searched in following steps: + 1. detect homogeneous device groups, we assume that the devices in the alpha_beta_dict + are homogeneous if the beta value is close enough. + 2. Find the best homogeneous device group contains all the physical devices. The best homogeneous + device group means the lowest beta value in the groups which contains all the physical devices. + And the reason we require the group contains all the physical devices is that the devices not in + the group will decrease the bandwidth of the group. + 3. If the best homogeneous device group is found, we will construct the largest ring for each device + based on the best homogeneous device group, and the best logical mesh will be the union of all the + rings. Otherwise, the best logical mesh will be the balanced logical mesh, such as shape (2, 2) for + 4 devices. + + Returns: + best_logical_mesh: The best logical mesh for the given device list. + + Usage: + >>> physical_devices = [0, 1, 2, 3] + >>> ab_profiler = AlphaBetaProfiler(physical_devices) + >>> best_logical_mesh = profiler.search_best_logical_mesh() + >>> print(best_logical_mesh) + [[0, 1], [2, 3]] + ''' + + def _power_of_two(integer): + return integer & (integer - 1) == 0 + + def _detect_homogeneous_device(alpha_beta_dict): + ''' + This function is used to detect whether the devices in the alpha_beta_dict are homogeneous. + + Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value + of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)] + * base_beta. + ''' + homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {} + for process_group, (_, beta) in alpha_beta_dict.items(): + if homogeneous_device_dict is None: + homogeneous_device_dict[beta] = [] + homogeneous_device_dict[beta].append(process_group) + + match_beta = None + for beta_value in homogeneous_device_dict.keys(): + if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * ( + 1 - self.homogeneous_tolerance): + match_beta = beta_value + break + + if match_beta is not None: + homogeneous_device_dict[match_beta].append(process_group) + else: + homogeneous_device_dict[beta] = [] + homogeneous_device_dict[beta].append(process_group) + + return homogeneous_device_dict + + def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]): + ''' + This function is used to check whether the homogeneous_group contains all physical devices. + ''' + flatten_mesh = [] + for process_group in homogeneous_group: + flatten_mesh.extend(process_group) + non_duplicated_flatten_mesh = set(flatten_mesh) + return len(non_duplicated_flatten_mesh) == len(self.physical_devices) + + def _construct_largest_ring(homogeneous_group: List[Tuple[int]]): + ''' + This function is used to construct the largest ring in the homogeneous_group for each rank. + ''' + # Construct the ring + ring = [] + ranks_in_ring = [] + for rank in self.physical_devices: + if rank in ranks_in_ring: + continue + stable_status = False + ring_for_rank = [] + ring_for_rank.append(rank) + check_rank_list = [rank] + rank_to_check_list = [] + + while not stable_status: + stable_status = True + check_rank_list.extend(rank_to_check_list) + rank_to_check_list = [] + for i in range(len(check_rank_list)): + check_rank = check_rank_list.pop() + for process_group in homogeneous_group: + if check_rank in process_group: + rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1] + if rank_to_append not in ring_for_rank: + stable_status = False + rank_to_check_list.append(rank_to_append) + ring_for_rank.append(rank_to_append) + + ring.append(ring_for_rank) + ranks_in_ring.extend(ring_for_rank) + + return ring + + assert _power_of_two(self.world_size) + power_of_two = int(math.log2(self.world_size)) + median = power_of_two // 2 + balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median)) + row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1] + balanced_logical_mesh = [] + for row_index in range(row_size): + balanced_logical_mesh.append([]) + for column_index in range(column_size): + balanced_logical_mesh[row_index].append(self.physical_devices[row_index * column_size + column_index]) + + homogeneous_device_dict = _detect_homogeneous_device(self.alpha_beta_dict) + beta_list = [b for b in homogeneous_device_dict.keys()] + beta_list.sort() + beta_list.reverse() + homogeneous_types = len(beta_list) + best_logical_mesh = None + if homogeneous_types >= 2: + for _ in range(homogeneous_types - 1): + lowest_beta = beta_list.pop() + best_homogeneous_group = homogeneous_device_dict[lowest_beta] + # if the best homogeneous group contains all physical devices, + # we will build the logical device mesh based on it. Otherwise, + # we will check next level homogeneous group. + if _check_contain_all_devices(best_homogeneous_group): + # We choose the largest ring for each rank to maximum the best bus utilization. + best_logical_mesh = _construct_largest_ring(best_homogeneous_group) + break + + if homogeneous_types == 1 or best_logical_mesh is None: + # in this case, we use balanced logical mesh as the best + # logical mesh. + best_logical_mesh = balanced_logical_mesh + + return best_logical_mesh + + def extract_alpha_beta_for_device_mesh(self): + ''' + Extract the mesh_alpha list and mesh_beta list based on the + best logical mesh, which will be used to initialize the device mesh. + + Usage: + >>> physical_devices = [0, 1, 2, 3] + >>> ab_profiler = AlphaBetaProfiler(physical_devices) + >>> mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh() + >>> print(mesh_alpha) + [2.5917552411556242e-05, 0.00010312341153621673] + >>> print(mesh_beta) + [5.875573704655635e-11, 4.7361584445959614e-12] + ''' + best_logical_mesh = self.search_best_logical_mesh() + + first_axis = [row[0] for row in best_logical_mesh] + second_axis = best_logical_mesh[0] + + # init process group for both axes + first_axis_process_group = dist.new_group(first_axis) + second_axis_process_group = dist.new_group(second_axis) + + # extract alpha and beta for both axes + def _extract_alpha_beta(pg, pg_handler): + latency = self.profile_latency(pg, pg_handler) + bandwidth = self.profile_bandwidth(pg, pg_handler) + broadcast_object = [latency, bandwidth] + dist.broadcast_object_list(broadcast_object, src=pg[0]) + return broadcast_object + + first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group) + second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group) + mesh_alpha = [first_latency, second_latency] + # The beta values have been enlarged by 1e10 times temporarilly because the computation cost + # is still estimated in the unit of TFLOPs instead of time. We will remove this factor in future. + mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth] + + return mesh_alpha, mesh_beta diff --git a/colossalai/device/calc_pipeline_strategy.py b/colossalai/device/calc_pipeline_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..4ab72dfe60f0c73f0e4f5186ed54205b68513bc0 --- /dev/null +++ b/colossalai/device/calc_pipeline_strategy.py @@ -0,0 +1,127 @@ +from math import pow + +import numpy as np + + +def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"): + submesh_choices = [] + i = 1 + p = -1 + while i <= num_devices_per_host: + i *= 2 + p += 1 + assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, " + f"while now num_devices_per_host = {num_devices_per_host}") + if mode == "alpa": + for i in range(p + 1): + submesh_choices.append((1, pow(2, i))) + for i in range(2, num_hosts + 1): + submesh_choices.append((i, num_devices_per_host)) + elif mode == "new": + for i in range(p // 2 + 1): + for j in range(i, p - i + 1): + submesh_choices.append((pow(2, i), pow(2, j))) + return submesh_choices + + +def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, + best_configs): + """Implementation of Alpa DP for pipeline strategy + Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf + + Arguments: + num_layers: K + num_devices: N*M + num_microbatches: B + submesh_choices: List[(n_i,m_i)] + compute_cost: t_intra + """ + # For f, layer ID start from 0 + # f[#pipeline stages, layer id that is currently being considered, number of devices used] + f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32) + f_stage_max = np.full((num_layers + 1, num_layers + 1, num_devices + 1), 0.0, dtype=np.float32) + f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3), -1, dtype=np.int32) + f[0, num_layers, 0] = 0 + for s in range(1, num_layers + 1): + for k in range(num_layers - 1, -1, -1): + for d in range(1, num_devices + 1): + for m, submesh in enumerate(submesh_choices): + n_submesh_devices = np.prod(np.array(submesh)) + if n_submesh_devices <= d: + # TODO: [luzgh]: Why alpa needs max_n_succ_stages? Delete. + # if s - 1 <= max_n_succ_stages[i, k - 1, m, n_config]: + # ... + for i in range(num_layers, k, -1): + stage_cost = compute_cost[k, i, m] + new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost + if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]): + f[s, k, d] = new_cost + f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices]) + f_argmin[s, k, d] = (i, m, best_configs[k, i, m]) + best_s = -1 + best_total_cost = np.inf + for s in range(1, num_layers + 1): + if f[s, 0, num_devices] < best_total_cost: + best_s = s + best_total_cost = f[s, 0, num_devices] + + if np.isinf(best_total_cost): + return np.inf, None + + total_cost = f[best_s, 0, num_devices] + (num_microbatches - 1) * f_stage_max[best_s, 0, num_devices] + current_s = best_s + current_layer = 0 + current_devices = num_devices + + res = [] + while current_s > 0 and current_layer < num_layers and current_devices > 0: + next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices]) + assert next_start_layer != -1 and current_devices != -1 + res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice)) + current_s -= 1 + current_layer = next_start_layer + current_devices -= np.prod(np.array(submesh_choices[submesh_choice])) + assert (current_s == 0 and current_layer == num_layers and current_devices == 0) + + return total_cost, res + + +def alpa_dp(num_layers, + num_devices, + num_microbatches, + submesh_choices, + num_autosharding_configs, + compute_cost, + gap=1e-6): + """Alpa auto stage dynamic programming. + Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py + + Arguments: + submesh_choices: List[(int,int)] + num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh) + compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs) + """ + assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices), + num_autosharding_configs), "Cost shape wrong." + all_possible_stage_costs = np.sort(np.unique(compute_cost)) + best_cost = np.inf + best_solution = None + last_max_stage_cost = 0.0 + # TODO: [luzgh]: Why alpa needs the num_autosharding_configs dimension in compute_cost? + # In dp_impl it seems the argmin n_config will be chosen. Just amin here. + best_configs = np.argmin(compute_cost, axis=3) + best_compute_cost = np.amin(compute_cost, axis=3) + assert len(all_possible_stage_costs), "no solution in auto stage construction." + for max_stage_cost in all_possible_stage_costs: + if max_stage_cost * num_microbatches >= best_cost: + break + if max_stage_cost - last_max_stage_cost < gap: + continue + cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, + max_stage_cost, best_configs) + if cost < best_cost: + best_cost = cost + best_solution = solution + last_max_stage_cost = max_stage_cost + + return best_cost, best_solution diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..2a5f747fbc238c7799b3076e695030020f491d5b --- /dev/null +++ b/colossalai/device/device_mesh.py @@ -0,0 +1,249 @@ +"""This code is adapted from Alpa + https://github.com/alpa-projects/alpa/ + with some changes. """ + +import operator +from functools import reduce +from typing import List, Tuple + +import torch +import torch.distributed as dist + + +# modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py) +class DeviceMesh: + """A logical view of a physical cluster. For example, we could view a physical cluster + with 16 devices as a device mesh with shape (2, 2, 4) or (4, 4). + + Arguments: + physical_mesh_id (torch.Tensor): physical view of the devices in global rank. + logical_mesh_id (torch.Tensor): logical view of the devices in global rank. + mesh_shape (torch.Size, optional): shape of logical view. + mesh_alpha (List[float], optional): coefficients used for computing + communication cost (default: None) + mesh_beta (List[float], optional): coefficients used for computing + communication cost (default: None) + init_process_group (bool, optional): initialize logical process group + during initializing the DeviceMesh instance if the init_process_group set to True. + Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group. + (default: False) + need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True. + """ + + def __init__(self, + physical_mesh_id: torch.Tensor, + mesh_shape: torch.Size = None, + logical_mesh_id: torch.Tensor = None, + mesh_alpha: List[float] = None, + mesh_beta: List[float] = None, + init_process_group: bool = False, + need_flatten: bool = True): + self.physical_mesh_id = physical_mesh_id + if logical_mesh_id is None: + self.mesh_shape = mesh_shape + self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape) + else: + self._logical_mesh_id = logical_mesh_id + self.mesh_shape = self._logical_mesh_id.shape + + # map global rank into logical rank + self.convert_map = {} + self._global_rank_to_logical_rank_map(self._logical_mesh_id, []) + # coefficient for alpha-beta communication model + if mesh_alpha is None: + mesh_alpha = [1] * len(self.mesh_shape) + if mesh_beta is None: + mesh_beta = [1] * len(self.mesh_shape) + self.mesh_alpha = tuple(mesh_alpha) + self.mesh_beta = tuple(mesh_beta) + self.init_process_group = init_process_group + self.need_flatten = need_flatten + if self.init_process_group: + self.process_groups_dict = self.create_process_groups_for_logical_mesh() + if self.need_flatten and self._logical_mesh_id.dim() > 1: + self.flatten_device_mesh = self.flatten() + # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten()) + # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha, + # self.mesh_beta) + + @property + def shape(self): + return self.mesh_shape + + @property + def num_devices(self): + return reduce(operator.mul, self.physical_mesh_id.shape, 1) + + @property + def logical_mesh_id(self): + return self._logical_mesh_id + + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k != 'process_groups_dict': + setattr(result, k, __import__("copy").deepcopy(v, memo)) + else: + setattr(result, k, v) + + return result + + def flatten(self): + """ + Flatten the logical mesh into an effective 1d logical mesh, + """ + flatten_mesh_shape_size = len(self.mesh_shape) + flatten_mesh_shape = [self.num_devices] + return DeviceMesh(self.physical_mesh_id, + tuple(flatten_mesh_shape), + mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), + mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), + init_process_group=self.init_process_group, + need_flatten=False) + + def _global_rank_to_logical_rank_map(self, tensor, index_list): + ''' + This method is a helper function to build convert_map recursively. + ''' + for index, inner_tensor in enumerate(tensor): + if inner_tensor.numel() == 1: + self.convert_map[int(inner_tensor)] = index_list + [index] + else: + self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index]) + + def create_process_groups_for_logical_mesh(self): + ''' + This method is used to initialize the logical process groups which will be used in communications + among logical device mesh. + Note: if init_process_group set to False, you have to call this method manually. Otherwise, + the communication related function, such as ShapeConsistencyManager.apply will raise errors. + ''' + process_groups_dict = {} + check_duplicate_list = [] + global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist() + for global_rank in global_rank_flatten_list: + process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank) + for axis, process_group in process_groups.items(): + if axis not in process_groups_dict: + process_groups_dict[axis] = [] + if process_group not in check_duplicate_list: + check_duplicate_list.append(process_group) + process_group_handler = dist.new_group(process_group) + process_groups_dict[axis].append((process_group, process_group_handler)) + + return process_groups_dict + + def global_rank_to_logical_rank(self, rank): + return self.convert_map[rank] + + def global_rank_to_process_groups_with_logical_rank(self, rank): + ''' + Give a global rank and return all logical process groups of this rank. + for example: + physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + mesh_shape = (4, 4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7], + # [8, 9, 10,11], + # [12,13,14,15]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + print(device_mesh.global_rank_to_process_groups_with_logical_rank(0)) + output: + # key is axis name + # value is a list of logical ranks in same axis with rank 0 + {0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]} + ''' + process_groups = {} + for d in range(self.logical_mesh_id.dim()): + for replacer in range(self.logical_mesh_id.shape[d]): + if d not in process_groups: + process_groups[d] = [] + process_group_member = self.convert_map[rank].copy() + process_group_member[d] = replacer + process_groups[d].append(process_group_member) + return process_groups + + def global_rank_to_process_groups_with_global_rank(self, rank): + ''' + Give a global rank and return all process groups of this rank. + for example: + physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + mesh_shape = (4, 4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7], + # [8, 9, 10,11], + # [12,13,14,15]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + print(device_mesh.global_rank_to_process_groups_with_global_rank(0)) + output: + # key is axis name + # value is a list of global ranks in same axis with rank 0 + {0: [0, 4, 8, 12], 1: [0, 1, 2, 3]} + ''' + logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank) + process_groups = {} + for dim, logical_ranks in logical_process_groups.items(): + process_groups[dim] = [] + for logical_rank in logical_ranks: + for g_rank, l_rank in self.convert_map.items(): + if l_rank == logical_rank: + process_groups[dim].append(g_rank) + return process_groups + + def all_gather_cost(self, num_bytes, mesh_dim): + num_devices = self.logical_mesh_id.shape[mesh_dim] + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + + 0.1) + + def all_reduce_cost(self, num_bytes, mesh_dim): + num_devices = self.logical_mesh_id.shape[mesh_dim] + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes + + 0.01) + + def reduce_scatter_cost(self, num_bytes, mesh_dim): + num_devices = self.logical_mesh_id.shape[mesh_dim] + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + + 0.001) + + def all_to_all_cost(self, num_bytes, mesh_dim): + num_devices = self.logical_mesh_id.shape[mesh_dim] + penalty_factor = num_devices / 2.0 + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * + (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) + + +class FlattenDeviceMesh(DeviceMesh): + + def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None): + super().__init__(physical_mesh_id, + mesh_shape, + mesh_alpha, + mesh_beta, + init_process_group=False, + need_flatten=False) + # Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars + self.mesh_alpha = max(self.mesh_alpha) + self.mesh_beta = min(self.mesh_beta) + # Different from original process_groups_dict, rank_list is not stored + self.process_number_dict = self.create_process_numbers_for_logical_mesh() + + def create_process_numbers_for_logical_mesh(self): + ''' + Build 1d DeviceMesh in column-major(0) and row-major(1) + for example: + mesh_shape = (2,4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} + ''' + num_devices = reduce(operator.mul, self.mesh_shape, 1) + process_numbers_dict = {} + process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist() + process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist() + return process_numbers_dict + + def mix_gather_cost(self, num_bytes): + num_devices = reduce(operator.mul, self.mesh_shape, 1) + return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1) diff --git a/colossalai/engine/__init__.py b/colossalai/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..158796befb312755ed92f77f7828557f55800e4c --- /dev/null +++ b/colossalai/engine/__init__.py @@ -0,0 +1,4 @@ +from ._base_engine import Engine +from .gradient_handler import * + +__all__ = ['Engine'] diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..ff8979d82401931b04649ffadff615932a1e1b37 --- /dev/null +++ b/colossalai/engine/_base_engine.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +# this code is inspired by the DeepSpeed library and implemented with our own design from scratch + +from typing import Iterable, List, Optional, Type + +from torch import Tensor +from torch.nn import Module +from torch.nn.modules.loss import _Loss + +from colossalai.engine.gradient_handler import BaseGradientHandler +from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule +from colossalai.logging import get_dist_logger +from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively + + +class Engine: + """Basic engine class for training and evaluation. It runs a specific process method + :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset. + It controls a iteration in training. + + Args: + model (``torch.nn.Module``): The neural network model. + optimizer (``colossalai.nn.optimizer.ColossalaiOptimizer``): Optimizer for updating the parameters. + criterion (``torch.nn.modules.loss._Loss``, optional): Loss function for calculating loss. + gradient_handlers (List[``BaseGradientHandler``], optional): A list of gradient handler used in backward. + clip_grad_norm (float, optional): The norm of gradient clipping. + ophook_list (list): List of ophook. + verbose (bool): whether to display log info. + schedule (''BaseSchedule''): Runtime schedule. + + Examples: + >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training + >>> model = ... + >>> criterion = ... + >>> optimizer = ... + >>> train_dataloader = ... + >>> engine, _, _, _ = colossalai.initialize(model, optimizer, criterion) + >>> engine.train() + >>> for inputs, labels in train_dataloader + >>> # set gradients to zero + >>> engine.zero_grad() + >>> # run forward pass + >>> outputs = engine(inputs) + >>> # compute loss value and run backward pass + >>> loss = engine.criterion(outputs, labels) + >>> engine.backward(loss) + >>> # update parameters + >>> engine.step() + + The example of using Engine in training could be find in + `Training with engine and trainer `_. and + `Run resnet cifar10 with engine `_. + """ + + def __init__(self, + model: Module, + optimizer: "ColossalaiOptimizer", + criterion: Optional[_Loss] = None, + gradient_handlers: Optional[List[BaseGradientHandler]] = None, + clip_grad_norm: float = 0.0, + ophook_list: Optional[List[BaseOpHook]] = None, + verbose: bool = True, + schedule: Optional[BaseSchedule] = None): + self._model = model + self._optimizer = optimizer + self._criterion = criterion + self._clip_grad_norm = clip_grad_norm + self._verbose = verbose + self._logger = get_dist_logger() + + # state + self.training = True # default + + # build gradient handler + if gradient_handlers: + self._gradient_handlers = gradient_handlers + else: + self._gradient_handlers = [] + + if ophook_list is None: + self._ophook_list = [] + else: + self._ophook_list = ophook_list + + # build schedule + if schedule: + assert isinstance(schedule, BaseSchedule), \ + f'expected schedule to be of type BaseSchedule, but got {type(schedule)}' + self._schedule = schedule + else: + self._schedule = NonPipelineSchedule() + if self.uses_pipeline: + self._schedule.pre_processing(self) + + # register hook if any + if len(self._ophook_list) > 0: + register_ophooks_recursively(self._model, self._ophook_list) + + @property + def ophooks(self): + """show current activated ophooks""" + return self._ophook_list + + @property + def model(self): + """Model attached to the engine""" + return self._model + + @property + def optimizer(self): + """Optimizer attached to the engine""" + return self._optimizer + + @property + def criterion(self): + """Criterion attached to the engine""" + return self._criterion + + @property + def schedule(self): + """Schedule attached to the engine""" + return self._schedule + + @property + def uses_pipeline(self): + """show the pipeline parallel used or not""" + return isinstance(self._schedule, (PipelineSchedule, InterleavedPipelineSchedule)) + + def add_hook(self, ophook: Type[BaseOpHook]) -> None: + """add necessary hook""" + # whether this hook exist + for h in self._ophook_list: + if type(h) == type(ophook): + logger = get_dist_logger() + logger.warning(f"duplicate hooks, at least two instance of {type(ophook)}") + self._ophook_list.append(ophook) + register_ophooks_recursively(self._model, self._ophook_list) + + def remove_hook(self, ophook: Type[BaseOpHook]) -> None: + """remove hook""" + logger = get_dist_logger() + logger.warning(f"removing hooks is currently not supported") + + def zero_grad(self): + """Set the gradient of parameters to zero + """ + self.optimizer.zero_grad() + + def step(self): + """Execute parameter update + """ + self._all_reduce_gradients() + self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm) + return self.optimizer.step() + + def backward(self, loss: Tensor): + """Start backward propagation given the loss value computed by a loss function. + + Args: + loss (:class:`torch.Tensor`): Loss value computed by a loss function. + """ + ret = self.optimizer.backward(loss) + for ophook in self._ophook_list: + ophook.post_iter() + return ret + + def backward_by_grad(self, tensor, grad): + """Start backward propagation given the gradient of the output tensor. + + Args: + tensor (:class:`torch.Tensor`): Output tensor. + grad (:class:`torch.Tensor`): Gradient passed back to the output. + """ + ret = self.optimizer.backward_by_grad(tensor, grad) + for ophook in self._ophook_list: + ophook.post_iter() + return ret + + def __call__(self, *args, **kwargs): + """Run the forward step for the model. + + Returns: + Tuple[:class:`torch.Tensor`] or :class:`torch.Tensor`: Output of the model. + """ + return self.model(*args, **kwargs) + + def _all_reduce_gradients(self): + """Handles all-reduce operations of gradients across different parallel groups. + """ + for handler in self._gradient_handlers: + handler.handle_gradient() + + def execute_schedule(self, data_iter: Iterable, **kwargs): + """Run the forward, loss computation, and backward for the model. + Returns a tuple of (output, label, loss). + + Returns: + Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss). + """ + output, label, loss = self._schedule.forward_backward_step(self, data_iter, **kwargs) + return output, label, loss + + def train(self): + """Sets the model to training mode. + """ + self.training = True + self._model.train() + + def eval(self): + """Sets the model to evaluation mode. + """ + self.training = False + self._model.eval() diff --git a/colossalai/engine/gradient_accumulation/__init__.py b/colossalai/engine/gradient_accumulation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4cb6f4ad7384dda6136d98a0a73521e37d4027ba --- /dev/null +++ b/colossalai/engine/gradient_accumulation/__init__.py @@ -0,0 +1,57 @@ +from typing import Iterable, List + +import torch.nn as nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + +from colossalai.engine import BaseGradientHandler + +from ._gradient_accumulation import ( + GradAccumDataloader, + GradAccumGradientHandler, + GradAccumLrSchedulerByStep, + GradAccumOptimizer, +) + +__all__ = [ + 'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep', + 'GradAccumGradientHandler' +] + + +def accumulate_gradient(model: nn.Module, + optimizer: Optimizer, + dataloader: Iterable, + accumulate_size: int, + gradient_handlers: List[BaseGradientHandler] = None, + lr_scheduler: _LRScheduler = None): + r"""Turning model, optimizer, dataloader into corresponding object for gradient accumulation. + + Args: + model (:class:`torch.nn.Module`): your model object for gradient accumulation. + optimizer (:class:`torch.optim.Optimizer`): your optimizer object for gradient accumulation. + dataloader (:class:`torch.utils.data.DataLoader` or iterable objects): + your dataloader object, would be called like iter(dataloader) + accumulate_size (int): the number of steps to accumulate gradients + gradient_handlers (List[:class:`colossalai.engine.BaseGradientHandler`]): + list of gradient handler objects. Default is None. + lr_scheduler (`torch.optim.lr_scheduler` or `colossalai.nn.lr_scheduler`): + your ``lr_scheduler`` object for gradient accumulation. Defaults to None. + + More details about `gradient_handlers` could be found in + `Gradient_handler `_. + + More details about `lr_scheduler` could be found + `lr_scheduler `_. and + `how to adjust learning rate `_. + """ + optimizer = GradAccumOptimizer(optimizer, accumulate_size=accumulate_size, model=model) + dataloader = GradAccumDataloader(dataloader, accumulate_size=accumulate_size) + + if gradient_handlers is not None: + gradient_handlers = [GradAccumGradientHandler(handler, accumulate_size) for handler in gradient_handlers] + + if lr_scheduler is not None: + lr_scheduler = GradAccumLrSchedulerByStep(lr_scheduler, accumulate_size=accumulate_size) + + return optimizer, dataloader, gradient_handlers, lr_scheduler diff --git a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py new file mode 100644 index 0000000000000000000000000000000000000000..cf66be1cd8218e8dc35177f9d69115ec1553c687 --- /dev/null +++ b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Any, Iterable, Tuple, Union + +import torch.nn as nn +from torch import Tensor +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader + +from colossalai.engine import BaseGradientHandler +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.utils import conditional_context + + +class GradAccumOptimizer(ColossalaiOptimizer): + """A wrapper for the optimizer to enable gradient accumulation by skipping the steps + before accumulation size is reached. + + Args: + optim (:class:`torch.optim.Optimizer`): Your optimizer object for gradient accumulation. + accumulate_size (int): The number of steps to accumulate gradients. + model (:class:`torch.nn.Module`): + Your model object to check if it is DistributedDataParallel for special handling of no_sync() context. + """ + + def __init__(self, optim: Optimizer, accumulate_size: int, model: nn.Module = None): + super().__init__(optim) + self.accumulate_size = accumulate_size + self.accumulate_step = 0 + + # handle pytorch ddp auto all reduce + self.model = model + self.is_torch_ddp = isinstance(self.model, DistributedDataParallel) + + def zero_grad(self, *args, **kwargs) -> None: + """ + Set all gradients to zero. + + Args: + *args: positional arguments for the optimizer wrapped + **kwargs: keyword arguments for the optimizer wrapped + """ + + if self.accumulate_step == 0: + self.optim.zero_grad(*args, **kwargs) + + def step(self, *args, **kwargs) -> None: + """ + Update the model parameters. + + Args: + *args: positional arguments for the optimizer wrapped + **kwargs: keyword arguments for the optimizer wrapped + """ + + if self.accumulate_step < self.accumulate_size: + return None + else: + self.accumulate_step = 0 + return self.optim.step(*args, **kwargs) + + def clip_grad_norm(self, model: nn.Module, max_norm: float) -> None: + """ + Clip gradients by norm. + + Args: + model (:class:`torch.nn.Module`): a torch module instance + max_norm (float): the max norm for gradient clipping + """ + + if self.accumulate_step < self.accumulate_size: + pass + else: + self.optim.clip_grad_norm(model, max_norm) + + def backward(self, loss: Tensor) -> None: + """Execute backward pass. + + Args: + loss (:class:`torch.Tensor`): the loss value. + """ + + self.accumulate_step += 1 + + if self.is_torch_ddp: + no_sync = self.accumulate_step < self.accumulate_size + with conditional_context(self.model.no_sync(), enable=no_sync): + scaled_loss = loss / self.accumulate_size + self.optim.backward(scaled_loss) + else: + scaled_loss = loss / self.accumulate_size + self.optim.backward(scaled_loss) + + def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None: + """Execute backward pass given the gradients of the output. + + Args: + loss (:class:`torch.Tensor`): the loss value. + grad (:class:`torch.Tensor`): the output gradient. + """ + + self.accumulate_step += 1 + no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size + + if no_sync: + with self.model.no_sync(): + self.optim.backward_by_grad(tensor, grad) + else: + self.optim.backward_by_grad(tensor, grad) + + +class GradAccumDataloader: + """A wrapper for dataloader to enable gradient accumulation by dropping the last incomplete steps. + + Note: + The dataloader would drop the last incomplete steps for gradient accumulation. + For example, if a dataloader has 10 batches of data and accumulate size is 4. The model parameters will + be updated only twice at step 4 and step 8. The last two batches of data do not form a complete 4-step cycle. + Thus, they will be automatically skipped by this class. If the dataloader is not standard PyTorch dataloader, + (e.g. Dali dataloader), this class will automatically consume (load data for nothing) the remaining 2 batches. + + Args: + dataloader (``Iterable``): Your dataloader object for gradient accumulation. + accumulate_size (int): The number of steps to accumulate gradients. + """ + + def __init__(self, dataloader: Iterable, accumulate_size: int) -> None: + self.dataloader = dataloader + self.consume_remain_data = not isinstance(dataloader, DataLoader) + self.steps_per_epoch = len(dataloader) - len(dataloader) % accumulate_size + + def __getattr__(self, __name: str) -> Any: + return getattr(self.dataloader, __name) + + def __len__(self) -> int: + return self.steps_per_epoch + + def __iter__(self) -> Iterable: + self._cur_step = 0 + self._dataiter = iter(self.dataloader) + return self + + def __next__(self) -> Union[Tensor, Tuple[Tensor]]: + if self._cur_step < self.steps_per_epoch: + self._cur_step += 1 + data = next(self._dataiter) + + if self._cur_step == self.steps_per_epoch and self.consume_remain_data: + # this is to handle non standard pytorch dataloader + # such as dali dataloader + while True: + try: + _ = next(self._dataiter) + except StopIteration: + break + return data + else: + raise StopIteration + + +class GradAccumLrSchedulerByStep(_LRScheduler): + """A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps + before accumulation size is reached. + + Args: + lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`): + Your ``lr_scheduler`` object for gradient accumulation. + accumulate_size (int): The number of steps to accumulate gradients. + """ + + def __init__(self, lr_scheduler: _LRScheduler, accumulate_size: int) -> None: + self.lr_scheduler = lr_scheduler + self.accumulate_size = accumulate_size + self.accumulate_step = 0 + + @staticmethod + def compute_effective_steps_per_epoch(dataloader: Iterable, accumulate_size: int) -> int: + """ + Computes the number of effective training iterations. An effective iteration is defined + as the the aggregation of iterations. For examples, if accumulate_size = 4, + then 4 iterations are considered as one effective iteration. + + Args: + dataloader (``Iterable``): Your dataloader object for gradient accumulation. + accumulate_size (int): The number of steps to accumulate gradients. + + """ + return len(dataloader) // accumulate_size + + def __getattr__(self, __name: str) -> Any: + return getattr(self.lr_scheduler, __name) + + def step(self, *args, **kwargs) -> None: + """ + Update the learning rate. + + Args: + *args: positional arguments for the lr scheduler wrapped. + **kwargs: keyword arguments for the lr scheduler wrapped. + """ + self.accumulate_step += 1 + if self.accumulate_step < self.accumulate_size: + pass + else: + self.accumulate_step = 0 + self.lr_scheduler.step(*args, **kwargs) + + def get_lr(self) -> Tensor: + """ + Compute the next learning rate. + + Returns: + Tensor: the upcoming learning rate. + """ + + return self.lr_scheduler.get_lr() + + def get_last_lr(self) -> Tensor: + """ + Returns the current learning rate. + + Returns: + Tensor: the current learning rate. + """ + + return self.lr_scheduler.get_last_lr() + + def print_lr(self, *args, **kwargs) -> None: + """ + Print he learning rate. + + Args: + *args: positional arguments for the lr scheduler wrapped. + **kwargs: keyword arguments for the lr scheduler wrapped. + """ + self.lr_scheduler.print_lr(*args, **kwargs) + + def state_dict(self) -> dict: + """ + Returns the states of the lr scheduler as dictionary. + + Returns: + dict: the states of the lr scheduler. + """ + return self.lr_scheduler.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + """ + Load the states of the lr scheduler from a dictionary object. + + Returns: + dict: the states of the lr scheduler. + """ + self.lr_scheduler.load_state_dict(state_dict) + + +class GradAccumGradientHandler: + r"""A wrapper for the gradient handler to enable gradient accumulation by skipping the steps + before accumulation size is reached. + + Args: + grad_handler (:class:`colossalai.engine.BaseGradientHandler`): + Your ``gradient_handler`` object for gradient accumulation, would be called when achieving `accumulate_size`. + accumulate_size (int): The number of steps to accumulate gradients. + + More details about ``gradient_handlers`` could be found in + `Gradient_handler `_. + + """ + + def __init__(self, grad_handler: BaseGradientHandler, accumulate_size: int) -> None: + assert isinstance(grad_handler, BaseGradientHandler), \ + f'expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}' + self.grad_handler = grad_handler + self.accumulate_size = accumulate_size + self.accumulate_step = 0 + + def handle_gradient(self) -> None: + """ + Handle gradients reduction only in the last gradient accumulation step. + """ + + self.accumulate_step += 1 + if self.accumulate_step < self.accumulate_size: + pass + else: + self.accumulate_step = 0 + self.grad_handler.handle_gradient() diff --git a/colossalai/engine/gradient_handler/__init__.py b/colossalai/engine/gradient_handler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2dea768bad7ecf1feed8bae69f733cda943509b5 --- /dev/null +++ b/colossalai/engine/gradient_handler/__init__.py @@ -0,0 +1,11 @@ +from ._base_gradient_handler import BaseGradientHandler +from ._data_parallel_gradient_handler import DataParallelGradientHandler +from ._moe_gradient_handler import MoeGradientHandler +from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler +from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler +from ._zero_gradient_handler import ZeROGradientHandler + +__all__ = [ + 'BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler', + 'MoeGradientHandler', 'SequenceParallelGradientHandler' +] diff --git a/colossalai/engine/gradient_handler/_base_gradient_handler.py b/colossalai/engine/gradient_handler/_base_gradient_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..7d96dd8a88a63d9f0c40ceefb99bf2809a37662d --- /dev/null +++ b/colossalai/engine/gradient_handler/_base_gradient_handler.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod + + +class BaseGradientHandler(ABC): + """A basic helper class to handle all-reduce operations of gradients across different parallel groups + before optimization. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def __init__(self, model, optimizer): + self._model = model + self._optimizer = optimizer + + @abstractmethod + def handle_gradient(self): + """A method to accumulate gradients across different parallel groups. Users should + write their own functions or just use the functions in pre-defined subclasses. + """ + pass diff --git a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..5cc7169c5a9f630dcb9e1b981f33c3fb35548cc0 --- /dev/null +++ b/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py @@ -0,0 +1,27 @@ +from colossalai.core import global_context as gpc +from colossalai.registry import GRADIENT_HANDLER + +from ...context.parallel_mode import ParallelMode +from ._base_gradient_handler import BaseGradientHandler +from .utils import bucket_allreduce + + +@GRADIENT_HANDLER.register_module +class DataParallelGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in a data parallel group. + A all-reduce collective communication will be operated in + :func:`handle_gradient` among a data parallel group. + For better performance, it bucketizes the gradients of all parameters that are + the same type to improve the efficiency of communication. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def handle_gradient(self): + """A method running a all-reduce operation in a data parallel group. + """ + # TODO: add memory buffer + if gpc.data_parallel_size > 1: + bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.DATA)) diff --git a/colossalai/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/engine/gradient_handler/_moe_gradient_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..b499345d4e184662b3a242aa71d24859ad843c7c --- /dev/null +++ b/colossalai/engine/gradient_handler/_moe_gradient_handler.py @@ -0,0 +1,46 @@ +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.core import global_context as gpc +from colossalai.registry import GRADIENT_HANDLER +from colossalai.utils.moe import get_moe_epsize_param_dict + +from ...context.parallel_mode import ParallelMode +from ._base_gradient_handler import BaseGradientHandler +from .utils import bucket_allreduce + + +@GRADIENT_HANDLER.register_module +class MoeGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in a data parallel group and + moe model parallel. A all-reduce collective communication will be operated in + :func:`handle_gradient` among a data parallel group. + For better performance, it bucketizes the gradients of all parameters that are + the same type to improve the efficiency of communication. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def __init__(self, model, optimizer=None): + super().__init__(model, optimizer) + + def handle_gradient(self): + """A method running an all-reduce operation in a data parallel group. + Then running an all-reduce operation for all parameters in experts + across moe model parallel group + """ + global_data = gpc.data_parallel_size + + if global_data > 1: + epsize_param_dict = get_moe_epsize_param_dict(self._model) + + # epsize is 1, indicating the params are replicated among processes in data parallelism + # use the ParallelMode.DATA to get data parallel group + # reduce gradients for all parameters in data parallelism + if 1 in epsize_param_dict: + bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA)) + + for ep_size in epsize_param_dict: + if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: + bucket_allreduce(param_list=epsize_param_dict[ep_size], + group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group) diff --git a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..5b49a9c0360dca600b8f1226ce0334f959e2265b --- /dev/null +++ b/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python + +from collections import defaultdict + +import torch +import torch.distributed as dist +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from colossalai.core import global_context as gpc +from colossalai.registry import GRADIENT_HANDLER + +from ._base_gradient_handler import BaseGradientHandler + + +@GRADIENT_HANDLER.register_module +class PipelineSharedModuleGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in sub parallel groups. + A all-reduce collective communication will be operated in + :func:`handle_gradient` among all sub pipeline parallel groups. + For better performance, it bucketizes the gradients of all parameters that are + the same type to improve the efficiency of communication. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def handle_gradient(self): + """A method running a all-reduce operation in sub pipeline parallel groups. + """ + if gpc.pipeline_parallel_size > 1: + # bucketize and all-reduce + buckets = defaultdict(lambda: defaultdict(list)) + # Pack the buckets. + for param in self._model.parameters(): + group = getattr(param, 'pipeline_shared_module_pg', None) + if param.requires_grad and group is not None and ( + (hasattr(param, 'colo_attr') and not param.colo_attr.saved_grad.is_null()) + or param.grad is not None): + tp = param.data.type() + buckets[group][tp].append(param) + + # For each bucket, all-reduce and copy all-reduced grads. + for group, group_buckets in buckets.items(): + for tp, bucket in group_buckets.items(): + grads = [ + param.colo_attr.grad_payload if hasattr(param, 'colo_attr') else param.grad.data + for param in bucket + ] + coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device()) + dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) diff --git a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ea4f0fbb1c718965deae37fc0a148aafca3d104a --- /dev/null +++ b/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py @@ -0,0 +1,26 @@ +from colossalai.core import global_context as gpc +from colossalai.registry import GRADIENT_HANDLER + +from ...context.parallel_mode import ParallelMode +from ._base_gradient_handler import BaseGradientHandler +from .utils import bucket_allreduce + + +@GRADIENT_HANDLER.register_module +class SequenceParallelGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in a data parallel group. + A all-reduce collective communication will be operated in + :func:`handle_gradient` among a data parallel group. + For better performance, it bucketizes the gradients of all parameters that are + the same type to improve the efficiency of communication. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def handle_gradient(self): + """A method running a all-reduce operation in a data parallel group. + """ + if gpc.get_world_size(ParallelMode.SEQUENCE_DP) > 1: + bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.SEQUENCE_DP)) diff --git a/colossalai/engine/gradient_handler/_zero_gradient_handler.py b/colossalai/engine/gradient_handler/_zero_gradient_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..19fd1e97f86f035826666f766fa1983eb9aae2cc --- /dev/null +++ b/colossalai/engine/gradient_handler/_zero_gradient_handler.py @@ -0,0 +1,21 @@ +from colossalai.registry import GRADIENT_HANDLER + +from ._base_gradient_handler import BaseGradientHandler + + +@GRADIENT_HANDLER.register_module +class ZeROGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in a data parallel group. + A all-reduce collective communication will be operated in + :func:`handle_gradient` among a data parallel group. + This class is specialized with ZeRO optimization. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def handle_gradient(self): + """A method running a all-reduce operation in a data parallel group. + """ + self._optimizer.sync_grad() diff --git a/colossalai/engine/gradient_handler/utils.py b/colossalai/engine/gradient_handler/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fca5f2ec9da9e73aa76e2e41ca5b99f6db8ef773 --- /dev/null +++ b/colossalai/engine/gradient_handler/utils.py @@ -0,0 +1,30 @@ +from typing import Iterable + +import torch.distributed as dist +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + + +def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None): + # get communication world size + comm_size = dist.get_world_size(group) + # bucketize and all-reduce + buckets = {} + # Pack the buckets. + for param in param_list: + if param.requires_grad and param.grad is not None: + tp = param.data.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(param) + + # For each bucket, all-reduce and copy all-reduced grads. + for tp in buckets: + bucket = buckets[tp] + grads = [param.grad.data for param in bucket] + coalesced = _flatten_dense_tensors(grads) + coalesced /= comm_size + + dist.all_reduce(coalesced, group=group) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) diff --git a/colossalai/engine/schedule/__init__.py b/colossalai/engine/schedule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f2c039d7057324676d30938c6ec112279077b61 --- /dev/null +++ b/colossalai/engine/schedule/__init__.py @@ -0,0 +1,5 @@ +from ._base_schedule import BaseSchedule +from ._non_pipeline_schedule import NonPipelineSchedule +from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape + +__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape'] diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..a2d50041127ace67726f1390fbb58331925e8af5 --- /dev/null +++ b/colossalai/engine/schedule/_base_schedule.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod +from typing import Callable, Iterable + +import torch + +from colossalai.logging import get_dist_logger +from colossalai.utils import get_current_device + + +class BaseSchedule(ABC): + """A basic helper class to control the process of training or evaluation. + It mainly composes of forward_backward_step for gradient backward and + optimizer_step for parameters update. + For the convenience to enable FP16, we aggregate all codes that contain the + control of FP16 in class schedule. + + Args: + data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges them into data and label. + """ + + def __init__(self, data_process_func: Callable = None): + self.logger = get_dist_logger() + self.data_process_func = data_process_func + + @staticmethod + def _move_tensor(element): + if torch.is_tensor(element): + if not element.is_cuda: + return element.to(get_current_device()).detach() + return element + + def _move_to_device(self, data): + if isinstance(data, torch.Tensor): + data = data.to(get_current_device()) + elif isinstance(data, (list, tuple)): + data_to_return = [] + for element in data: + if isinstance(element, dict): + data_to_return.append({k: self._move_tensor(v) for k, v in element.items()}) + else: + data_to_return.append(self._move_tensor(element)) + data = data_to_return + elif isinstance(data, dict): + data = {k: self._move_tensor(v) for k, v in data.items()} + else: + raise TypeError( + f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") + return data + + def _get_batch_size(self, data): + if isinstance(data, torch.Tensor): + return data.size(0) + elif isinstance(data, (list, tuple)): + if isinstance(data[0], dict): + return data[0][list(data[0].keys())[0]].size(0) + return data[0].size(0) + elif isinstance(data, dict): + return data[list(data.keys())[0]].size(0) + + def load_batch(self, data_iter, to_gpu=True): + """Loads a batch from data iterator. It returns the data and labels which are + already in the same GPU as where the model's. + + Args: + data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader). + to_gpu (bool, optional): Whether the data should be moved to GPU + + Returns: + Tuple (:class:`Tensor`, :class:`torch.Tensor`): A tuple of (data, label). + """ + if data_iter is None: + raise RuntimeError('Dataloader is not defined.') + batch_data = next(data_iter) + + if to_gpu: + batch_data = self._move_to_device(batch_data) + self.batch_size = self._get_batch_size(batch_data) + return batch_data + + def pre_processing(self, engine): + """To perform actions before running the schedule. + """ + pass + + @abstractmethod + def forward_backward_step(self, + engine, + data_iter: Iterable, + forward_only: bool, + return_loss: bool = True, + return_output_label: bool = True): + """The process function over a batch of dataset for training or evaluation. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader). + forward_only (bool): If True, the process won't include backward. + return_loss (bool, optional): If False, the loss won't be returned. + return_output_label (bool, optional): If False, the output and label won't be returned. + """ + pass + + @staticmethod + def _call_engine(engine, inputs): + if isinstance(inputs, torch.Tensor): + return engine(inputs) + elif isinstance(inputs, (list, tuple)): + return engine(*inputs) + elif isinstance(inputs, dict): + return engine(**inputs) + else: + TypeError( + f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}") + + @staticmethod + def _call_engine_criterion(engine, outputs, labels): + assert isinstance(outputs, + (torch.Tensor, list, tuple, + dict)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}' + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + if isinstance(labels, torch.Tensor): + labels = (labels,) + + if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)): + return engine.criterion(*outputs, *labels) + elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict): + return engine.criterion(*outputs, **labels) + elif isinstance(outputs, dict) and isinstance(labels, dict): + return engine.criterion(**outputs, **labels) + elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)): + raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}") + else: + raise TypeError(f"Expected model outputs and labels to be of type torch.Tensor ' \ + '(which is auto-converted to tuple), list, tuple, or dict, ' \ + 'but got {type(outputs)} (model outputs) and {type(labels)} (labels)") diff --git a/colossalai/engine/schedule/_non_pipeline_schedule.py b/colossalai/engine/schedule/_non_pipeline_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..b9239d928a7ba4e9471071f1c4e08c8443f5edb1 --- /dev/null +++ b/colossalai/engine/schedule/_non_pipeline_schedule.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import inspect +from typing import Callable, Iterable + +import torch + +from colossalai.utils import conditional_context + +from ._base_schedule import BaseSchedule + + +class NonPipelineSchedule(BaseSchedule): + """A helper schedule class for no pipeline parallelism running environment. + During one process, it loads a batch of dataset and feeds it to the model. + After getting the output and calculating the loss, it will use :meth:`step` + to update the parameters if it is in training mode. + + Args: + data_process_func (Callable, optional): The preprocessing function which receives a batch of data + and returns a tuple in the form of (data, label). + and it will be executed in load_batch. + + Example: + # this shows an example of customized data_process_func + def data_process_func(dataloader_output): + item1, item2, item3 = dataloader_output + data = (item1, item2) + label = item3 + return data, label + """ + + def __init__(self, data_process_func: Callable = None): + # check that non-pipeline schedule data process func only takes in one parameter + # which is the batch data + + if data_process_func: + sig = inspect.signature(data_process_func) + assert len(sig.parameters) == 1, \ + 'The data_process_func only takes in one parameter for NonPipelineSchedule, ' \ + 'which is a tuple of tensors for the current batch, ' \ + 'i.e. data_process_func(dataloader_output).' + + super().__init__(data_process_func) + + def forward_backward_step(self, + engine, + data_iter: Iterable, + forward_only: bool = False, + return_loss: bool = True, + return_output_label: bool = True): + """The process function that loads a batch of dataset and feeds it to the model. + The returned labels and loss will None if :attr:`return_loss` is False. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). + forward_only (bool, optional): + If True, the model is run for the forward pass, else back propagation will be executed. + return_loss (bool, optional): Loss will be returned if True. + return_output_label (bool, optional): Output and label will be returned if True. + + Returns: + Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. + """ + assert forward_only or return_loss, \ + "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." + batch_data = self.load_batch(data_iter) + if self.data_process_func: + data, label = self.data_process_func(batch_data) + else: + # if not batch data process func is given, + # then we regard the batch data as a simple tuple of (data, label) + data, label = batch_data + + # forward + with conditional_context(torch.no_grad(), enable=forward_only): + output = self._call_engine(engine, data) + if return_loss: + loss = self._call_engine_criterion(engine, output, label) + + if not forward_only: + engine.backward(loss) + + if return_output_label: + if return_loss: + return output, label, loss + else: + return output, label, None + else: + if return_loss: + return None, None, loss + else: + return None, None, None diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..38175fe0941c1c053bf91fe1df558ee9e763c360 --- /dev/null +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -0,0 +1,833 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import inspect +from typing import Callable, List, Tuple, Union + +import torch.cuda + +import colossalai.communication as comm +from colossalai.amp.naive_amp import NaiveAMPModel +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.utils import switch_virtual_pipeline_parallel_rank +from colossalai.utils.cuda import get_current_device + +from ._base_schedule import BaseSchedule + + +def get_tensor_shape(): + if hasattr(gpc.config, 'TENSOR_SHAPE'): + return gpc.config.TENSOR_SHAPE + + if not gpc.is_initialized(ParallelMode.PIPELINE): + return None + + if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr( + gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'): + if gpc.is_initialized(ParallelMode.DATA): + dp_size = gpc.get_world_size(ParallelMode.DATA) + else: + dp_size = 1 + if gpc.is_initialized(ParallelMode.SEQUENCE): + seq_size = gpc.get_world_size(ParallelMode.SEQUENCE) + else: + seq_size = 1 + + tensor_shape = (gpc.config.SEQ_LENGTH // seq_size, + gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, gpc.config.HIDDEN_SIZE) + return tensor_shape + else: + return None + + +def pack_return_tensors(return_tensors): + output, label = tuple(zip(*return_tensors)) + if isinstance(output[0], torch.Tensor): + output = torch.cat(output, dim=0) + elif isinstance(output[0], (list, tuple)): + output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output)) + else: + raise TypeError(f'Output of model must be tensor or list/tuple of tensors') + if isinstance(label[0], torch.Tensor): + label = torch.cat(label, dim=0) + else: + merged_label = {k: [] for k in label[0].keys()} + for d in label: + for k, v in d.items(): + merged_label[k].append(v) + label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()} + return output, label + + +class PipelineSchedule(BaseSchedule): + """A helper schedule class for pipeline parallelism running environment. + It uses non-interleaved 1F1B strategy. Other properties are similar as + :class:`NonPipelineSchedule`. + + Args: + num_microbatches (int): The number of microbatches. + data_process_func (Callable, optional): + The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. + tensor_shape (torch.Size, optional): Specified shape in pipeline communication. + scatter_gather_tensors (bool, optional): + If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. + + Example: + + # this shows an example of customized data_process_func + def data_process_func(stage_output, dataloader_output): + output1, output2 = stage_output + item1, item2, item3 = dataloader_output + + # assume item2 is not needed + data = (output1, output2, item1) + label = item3 + return data, label + + """ + + def __init__(self, + num_microbatches, + data_process_func: Callable = None, + tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, + scatter_gather_tensors: bool = False): + + # we need to make sure that the signature of the data_process_func is valid + if data_process_func: + sig = inspect.signature(data_process_func) + assert len(sig.parameters) == 2, \ + 'The data_process_func only takes in two parameters for NonPipelineSchedule, ' \ + 'which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, ' \ + 'i.e. data_process_func(stage_output, dataloader_output).' + + super().__init__(data_process_func=data_process_func) + + assert num_microbatches > 0, f'expected num_microbatches to be larger then 1, but got {num_microbatches}' + + self.num_microbatches = num_microbatches + self.dtype = torch.float + assert not isinstance(tensor_shape, + int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]." + if tensor_shape is None: + self.tensor_shape = tensor_shape + elif isinstance(tensor_shape, torch.Size): + self.tensor_shape = tensor_shape + else: + self.tensor_shape = torch.Size(tensor_shape) + self.scatter_gather_tensors = False + if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1: + self.scatter_gather_tensors = scatter_gather_tensors + self._logger = get_dist_logger() + + # cache for the batch data + self.batch_data = None + + def load_batch(self, data_iter): + # Pipeline schedule just puts data in memory + batch_data = super().load_batch(data_iter, to_gpu=False) + self.microbatch_offset = 0 + assert self.batch_size % self.num_microbatches == 0, \ + "Batch size should divided by the number of microbatches" + self.microbatch_size = self.batch_size // self.num_microbatches + self.batch_data = batch_data + + def _get_data_slice(self, data, offset): + if isinstance(data, torch.Tensor): + return data[offset:offset + self.microbatch_size] + elif isinstance(data, (list, tuple)): + data_dict = {} + for element in data: + if isinstance(element, dict): + data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()}) + elif data_dict: + data_dict['label'] = element[offset:offset + self.microbatch_size] + if data_dict: + return data_dict + return [val[offset:offset + self.microbatch_size] for val in data] + elif isinstance(data, dict): + return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()} + else: + raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") + + def load_micro_batch(self): + mciro_batch_data = self._get_data_slice(self.batch_data, self.microbatch_offset) + self.microbatch_offset += self.microbatch_size + return self._move_to_device(mciro_batch_data) + + def pre_processing(self, engine): + from colossalai.zero.legacy import ShardedModelV2 + + # TODO: remove this after testing new zero with pipeline parallelism + model = engine.model + if isinstance(model, NaiveAMPModel): + self.dtype = torch.half + model = model.model + if isinstance(model, ShardedModelV2): + self.dtype = torch.half + model = model.module + # sig = inspect.signature(model.forward) + # for p in sig.parameters.values(): + # assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported' + + @staticmethod + def _call_engine(model, data): + if data is not None: + if isinstance(data, torch.Tensor): + return model(data) + elif isinstance(data, (list, tuple)): + return model(*data) + elif isinstance(data, dict): + stage_output = None + if 'stage_output' in data: + stage_output = data.pop('stage_output') + if stage_output is None: + return model(**data) + elif isinstance(stage_output, torch.Tensor): + return model(stage_output, **data) + elif isinstance(stage_output, (tuple, list)): + return model(*stage_output, **data) + else: + raise TypeError( + f"Expected stage_output to be of type torch.Tensor, list, or tuple, but got {type(stage_output)}" + ) + else: + raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") + + def _get_actual_forward_func(self, module): + if isinstance(module, NaiveAMPModel): + sig = inspect.signature(module.model.forward) + elif hasattr(module, 'colo_attr'): + sig = inspect.signature(module.module.forward) + else: + sig = inspect.signature(module.forward) + return sig + + def _get_data_label_for_current_step(self, stage_output, micro_batch_data, criterion, model): + if self.data_process_func: + # use customized function to get data and label + data, label = self.data_process_func(stage_output, micro_batch_data) + else: + if isinstance(micro_batch_data, (tuple, list)): + if gpc.is_first_rank(ParallelMode.PIPELINE): + # for the first stage, we use the data from the + # dataloader output by default + data, label = micro_batch_data + else: + # for non-first stage, we use the output passed + # by the previous as the model input + data = stage_output + _, label = micro_batch_data + elif isinstance(micro_batch_data, dict): + data = {} + data['stage_output'] = stage_output + if 'label' in micro_batch_data: + label = micro_batch_data.pop('label') + else: + label = None + load_data = micro_batch_data + data.update(load_data) + return data, label + + def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None): + """Forward step for passed-in model. If it is the first stage, the input tensor + is obtained from data_iterator, otherwise the passed-in input_obj is used. + Returns output tensor. This is a helper function and can be ignored by users. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage. + return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return. + return_output_label (bool, optional): Whether returns output labels. + accum_loss (optional): Where accumulated loss stores. + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. + """ + micro_batch_data = self.load_micro_batch() + + data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, engine.model) + + output_obj = self._call_engine(engine.model, data) + + if gpc.is_last_rank(ParallelMode.PIPELINE): + if return_output_label: + return_tensors.append((output_obj, label)) + if accum_loss is not None: + loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches + accum_loss.add_(loss_reduced.detach()) + return loss_reduced + else: + # forward only, it's useless since backward is not needed + return output_obj + else: + if isinstance(output_obj, torch.Tensor): + self._logger.debug( + f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}' + ) + return output_obj + + def _backward_step(self, engine, input_obj, output_obj, output_obj_grad): + """Backward step through the passed-in output tensor. If it is the last stage, the + output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor. + Returns the gradients with respect to the input tensor (None if first stage). + This is a helper function and can be ignored by users. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage. + output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage. + output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: gradient of input tensor. + """ + + # Retain the grad on the input_obj. + if input_obj is not None: + if isinstance(input_obj, torch.Tensor): + input_obj.retain_grad() + else: + for in_tensor in input_obj: + if in_tensor is not None: + in_tensor.retain_grad() + # Backward pass. + if output_obj_grad is None: + engine.backward(output_obj) + else: + engine.backward_by_grad(output_obj, output_obj_grad) + + # Collect the grad of the input_obj. + input_obj_grad = None + if input_obj is not None: + if isinstance(input_obj, torch.Tensor): + input_obj_grad = input_obj.grad + else: + input_obj_grad = [] + for in_tensor in input_obj: + input_obj_grad.append(in_tensor.grad) + + return input_obj_grad + + def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): + """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. + Returns a tuple with losses if the last stage, an empty tuple otherwise. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). + forward_only (bool, optional): + Whether run forward step only. Default is false. If true, no backward will be run. + return_loss (bool, optional): Whether returns the loss value. Default is true. + return_output_label (bool, optional): If False, the output and label won't be returned. + + Returns: + Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. + """ + + assert forward_only or return_loss, \ + 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' + self.load_batch(data_iter) + num_warmup_microbatches = \ + (gpc.get_world_size(ParallelMode.PIPELINE) + - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) + num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) + num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches + + # Input, output tensors only need to be saved when doing backward passes + input_objs = None + output_objs = None + if not forward_only: + input_objs = [] + output_objs = [] + return_tensors = [] + if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None + # Used for tensor meta information communication + ft_shapes = self.tensor_shape + bt_shapes = None + fs_checker = self.tensor_shape is None + + # Run warmup forward passes. + for i in range(num_warmup_microbatches): + if not gpc.is_first_rank(ParallelMode.PIPELINE): + ft_shapes = comm.recv_obj_meta(ft_shapes) + input_obj = comm.recv_forward(ft_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + output_obj = self._forward_step(engine, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss) + if not gpc.is_last_rank(ParallelMode.PIPELINE): + if isinstance(output_obj, torch.Tensor): + bt_shapes = output_obj.shape + else: + bt_shapes = [] + for out_tensor in output_obj: + bt_shapes.append(out_tensor.shape) + fs_checker = comm.send_obj_meta(output_obj, fs_checker) + comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors) + + if not forward_only: + input_objs.append(input_obj) + output_objs.append(output_obj) + + # Before running 1F1B, need to receive first forward tensor. + # If all microbatches are run in warmup / cooldown phase, then no need to + # receive this tensor here. + if num_microbatches_remaining > 0: + if not gpc.is_first_rank(ParallelMode.PIPELINE): + ft_shapes = comm.recv_obj_meta(ft_shapes) + input_obj = comm.recv_forward(ft_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + last_iteration = (i == (num_microbatches_remaining - 1)) + + output_obj = self._forward_step(engine, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss) + if forward_only: + comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors) + + if not last_iteration: + input_obj = comm.recv_forward(ft_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + + else: + output_obj_grad = comm.send_forward_recv_backward(output_obj, + bt_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + + # Add input_obj and output_obj to end of list. + input_objs.append(input_obj) + output_objs.append(output_obj) + + # Pop output_obj and output_obj from the start of the list for + # the backward pass. + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) + + if last_iteration: + input_obj = None + comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) + else: + input_obj = comm.send_backward_recv_forward(input_obj_grad, + ft_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + + # Run cooldown backward passes. + if not forward_only: + for i in range(num_warmup_microbatches): + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + + output_obj_grad = comm.recv_backward(bt_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) + + comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) + + if len(return_tensors) > 0: + output, label = pack_return_tensors(return_tensors) + return output, label, accum_loss + else: + return None, None, accum_loss + + +class InterleavedPipelineSchedule(PipelineSchedule): + + def __init__(self, + num_microbatches: int, + num_model_chunks: int, + data_process_func: Callable = None, + tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, + scatter_gather_tensors: bool = False): + """A helper schedule class for pipeline parallelism running environment. + It uses interleaved 1F1B strategy. Other properties are similar as + :class:`NonPipelineSchedule`. + + Args: + num_microbatches (int): The number of microbatches. + num_model_chunks (int): The number of model chunks. + data_process_func (Callable, optional): + The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. + tensor_shape (torch.Size, optional): Specified shape in pipeline communication. + scatter_gather_tensors (bool, optional): + If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. + """ + assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \ + 'num_microbatches must be an integer multiple of pipeline parallel world size' + assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \ + f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}' + super().__init__(num_microbatches, + data_process_func=data_process_func, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather_tensors) + gpc.set_virtual_pipeline_parallel_size(num_model_chunks) + gpc.set_virtual_pipeline_parallel_rank(0) + self.num_model_chunks = num_model_chunks + + def pre_processing(self, engine): + from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 + if isinstance(engine.model, ShardedModelV2): + self.dtype = torch.half + elif isinstance(engine.model[0], NaiveAMPModel): + self.dtype = torch.half + for model in engine.model: + if isinstance(model, NaiveAMPModel): + model = model.model + sig = inspect.signature(model.forward) + for p in sig.parameters.values(): + assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported' + + def load_batch(self, data_iter): + super().load_batch(data_iter) + # overwrite microbatch_offset, since model chunks load the same microbatch, and should tract the offset + self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] + + def load_micro_batch(self, model_chunk_id): + data = self._get_data_slice(self.batch_data, self.microbatch_offset[model_chunk_id]) + self.microbatch_offset[model_chunk_id] += self.microbatch_size + return self._move_to_device(data) + + def _forward_step(self, + engine, + model_chunk_id, + input_obj, + return_tensors, + return_output_label=True, + accum_loss=None): + """Forward step for passed-in model. If it is the first stage, the input tensor + is obtained from data_iterator, otherwise the passed-in input_obj is used. + Returns output tensor. This is a helper function and can be ignored by users. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + model_chunk_id (int): The id of model chunks. + input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage. + return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return. + return_output_label (bool, optional): Whether returns output labels. + accum_loss (optional): Where accumulated loss stores. + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. + """ + micro_batch_data = self.load_micro_batch(model_chunk_id) + data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, + engine.model[model_chunk_id]) + + output_obj = self._call_engine(engine.model[model_chunk_id], data) + + if gpc.is_pipeline_last_stage(): + if return_output_label: + return_tensors.append((output_obj, label)) + if accum_loss is not None: + loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches + accum_loss.add_(loss_reduced.detach()) + return loss_reduced + else: + # forward only, it's useless since backward is not needed + return output_obj + else: + if isinstance(output_obj, torch.Tensor): + self._logger.debug( + f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}' + ) + return output_obj + + def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): + """Run interleaved 1F1B schedule (model split into model chunks), with + communication between pipeline stages as needed. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). + forward_only (bool, optional): + Whether run forward step only. Default is false. If true, no backward will be run. + return_loss (bool, optional): Whether returns the loss value. Default is true. + return_output_label (bool, optional): If False, the output and label won't be returned. + + Returns: + Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. + The loss would be returned only in the last stage. + """ + assert forward_only or return_loss, \ + 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' + self.load_batch(data_iter) + model = engine.model + input_objs = [[] for _ in range(len(model))] + output_objs = [[] for _ in range(len(model))] + return_tensors = [] + if not forward_only: + output_obj_grads = [[] for _ in range(len(model))] + if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None + + # Used for obj meta information communication + input_obj_shapes = [self.tensor_shape for _ in range(len(model))] + output_obj_shapes = [None for _ in range(len(model))] + send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))] + + pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + # Compute number of warmup and remaining microbatches. + num_model_chunks = len(model) + num_microbatches = self.num_microbatches * num_model_chunks + all_warmup_microbatches = False + if forward_only: + num_warmup_microbatches = num_microbatches + else: + # Run all forward passes and then all backward passes if number of + # microbatches is just the number of pipeline stages. + # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on + # all workers, followed by more microbatches after depending on + # stage ID (more forward passes for earlier stages, later stages can + # immediately start with 1F1B). + if self.num_microbatches == pipeline_parallel_size: + num_warmup_microbatches = num_microbatches + all_warmup_microbatches = True + else: + num_warmup_microbatches = \ + (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 + num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + num_microbatches_remaining = \ + num_microbatches - num_warmup_microbatches + + def get_model_chunk_id(microbatch_id, forward): + """Helper method to get the model chunk ID given the iteration number.""" + microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) + model_chunk_id = microbatch_id_in_group // pipeline_parallel_size + if not forward: + model_chunk_id = (num_model_chunks - model_chunk_id - 1) + return model_chunk_id + + def _forward_step_helper(microbatch_id): + """Helper method to run forward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + forward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) + gpc.set_virtual_pipeline_parallel_rank(model_chunk_id) + + # forward step + if gpc.is_pipeline_first_stage(): + if len(input_objs[model_chunk_id]) == \ + len(output_objs[model_chunk_id]): + input_objs[model_chunk_id].append(None) + input_obj = input_objs[model_chunk_id][-1] + output_obj = self._forward_step(engine, + model_chunk_id, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss) + output_objs[model_chunk_id].append(output_obj) + + # if forward-only, no need to save tensors for a backward pass + if forward_only: + input_objs[model_chunk_id].pop() + output_objs[model_chunk_id].pop() + + return output_obj + + def _backward_step_helper(microbatch_id): + """Helper method to run backward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + backward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) + gpc.set_virtual_pipeline_parallel_rank(model_chunk_id) + + if gpc.is_pipeline_last_stage(): + if len(output_obj_grads[model_chunk_id]) == 0: + output_obj_grads[model_chunk_id].append(None) + input_obj = input_objs[model_chunk_id].pop(0) + output_obj = output_objs[model_chunk_id].pop(0) + output_obj_grad = output_obj_grads[model_chunk_id].pop(0) + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) + + return input_obj_grad + + # Run warmup forward passes. + gpc.set_virtual_pipeline_parallel_rank(0) + if not gpc.is_pipeline_first_stage(): + input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0]) + input_objs[0].append( + comm.recv_forward(input_obj_shapes[0], dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors)) + + for k in range(num_warmup_microbatches): + model_chunk_id = get_model_chunk_id(k, forward=True) + output_obj = _forward_step_helper(k) + if not gpc.is_pipeline_last_stage(): + if isinstance(output_obj, torch.Tensor): + output_obj_shapes[model_chunk_id] = output_obj.shape + else: + output_obj_shapes[model_chunk_id] = [] + for out_tensor in output_obj: + output_obj_shapes[model_chunk_id].append(out_tensor.shape) + send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(output_obj, + send_tensor_shape_flags[model_chunk_id]) + # Determine if tensor should be received from previous stage. + next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) + recv_prev = True + if gpc.is_pipeline_first_stage(ignore_virtual=True): + if next_forward_model_chunk_id == 0: + recv_prev = False + if k == (num_microbatches - 1): + recv_prev = False + + # Don't send tensor downstream if on last stage. + if gpc.is_pipeline_last_stage(): + output_obj = None + + with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id): + if not gpc.is_pipeline_first_stage(): + input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta( + input_obj_shapes[next_forward_model_chunk_id]) + # Send and receive tensors as appropriate (send tensors computed + # in this iteration; receive tensors for next iteration). + input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None + if k == (num_warmup_microbatches - 1) and not forward_only and \ + not all_warmup_microbatches: + input_obj_grad = None + recv_next = True + if gpc.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None + input_obj, output_obj_grad = \ + comm.send_forward_backward_recv_forward_backward( + output_obj, input_obj_grad, + input_shape, + output_shape, + recv_prev=recv_prev, recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + output_obj_grads[num_model_chunks - 1].append(output_obj_grad) + else: + input_obj = \ + comm.send_forward_recv_forward( + output_obj, + input_shape, + recv_prev=recv_prev, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + input_objs[next_forward_model_chunk_id].append(input_obj) + + # Run 1F1B in steady state. + for k in range(num_microbatches_remaining): + # Forward pass. + forward_k = k + num_warmup_microbatches + output_obj = _forward_step_helper(forward_k) + + # Backward pass. + backward_k = k + input_obj_grad = _backward_step_helper(backward_k) + + # Send output_obj and input_obj_grad, receive input_obj + # and output_obj_grad. + + # Determine if current stage has anything to send in either direction, + # otherwise set obj to None. + forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id) + if gpc.is_pipeline_last_stage(): + output_obj = None + + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id) + if gpc.is_pipeline_first_stage(): + input_obj_grad = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if gpc.is_pipeline_first_stage(ignore_virtual=True): + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id(forward_k - (pipeline_parallel_size - 1), forward=True) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) + + recv_next = True + if gpc.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). + next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1), + forward=False) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) + + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False + + input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None + output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None + # Communicate objs. + input_obj, output_obj_grad = \ + comm.send_forward_backward_recv_forward_backward( + output_obj, input_obj_grad, + input_shape, + output_shape, + recv_prev=recv_prev, recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) + + # Put input_obj and output_obj_grad in data structures in the + # right location. + if recv_prev: + input_objs[next_forward_model_chunk_id].append(input_obj) + if recv_next: + output_obj_grads[next_backward_model_chunk_id].append(output_obj_grad) + + # Run cooldown backward passes (flush out pipeline). + if not forward_only: + if all_warmup_microbatches: + output_obj_grads[num_model_chunks - 1].append( + comm.recv_backward(output_obj_shapes[num_model_chunks - 1], + scatter_gather_tensors=self.scatter_gather_tensors)) + for k in range(num_microbatches_remaining, num_microbatches): + input_obj_grad = _backward_step_helper(k) + next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) + recv_next = True + if gpc.is_pipeline_last_stage(ignore_virtual=True): + if next_backward_model_chunk_id == (num_model_chunks - 1): + recv_next = False + if k == (num_microbatches - 1): + recv_next = False + output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None + output_obj_grads[next_backward_model_chunk_id].append( + comm.send_backward_recv_backward(input_obj_grad, + output_shape, + recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors)) + + if len(return_tensors) > 0: + output, label = pack_return_tensors(return_tensors) + return output, label, accum_loss + else: + return None, None, accum_loss diff --git a/colossalai/engine/schedule/_pipeline_schedule_v2.py b/colossalai/engine/schedule/_pipeline_schedule_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..28c58bd82b5c3f6969337c0a718a5698346744d9 --- /dev/null +++ b/colossalai/engine/schedule/_pipeline_schedule_v2.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Iterable, Tuple + +import torch.cuda + +import colossalai.communication.p2p_v2 as comm +from colossalai import engine +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils.cuda import get_current_device + +from ._pipeline_schedule import PipelineSchedule + + +def pack_return_tensors(return_tensors): + output, label = tuple(zip(*return_tensors)) + if isinstance(output[0], torch.Tensor): + output = torch.cat(output, dim=0) + elif isinstance(output[0], (list, tuple)): + output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output)) + else: + raise TypeError(f'Output of model must be tensor or list/tuple of tensors') + if isinstance(label[0], torch.Tensor): + label = torch.cat(label, dim=0) + else: + merged_label = {k: [] for k in label[0].keys()} + for d in label: + for k, v in d.items(): + merged_label[k].append(v) + label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()} + return output, label + + +class PipelineScheduleV2(PipelineSchedule): + """Derived class of PipelineSchedule, the only difference is that + forward_backward_step is reconstructed with p2p_v2 + + Args: + num_microbatches (int): The number of microbatches. + data_process_func (Callable, optional): + The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. + tensor_shape (torch.Size, optional): Specified shape in pipeline communication. + scatter_gather_tensors (bool, optional): + If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. + + Example: + + # this shows an example of customized data_process_func + def data_process_func(stage_output, dataloader_output): + output1, output2 = stage_output + item1, item2, item3 = dataloader_output + + # assume item2 is not needed + data = (output1, output2, item1) + label = item3 + return data, label + + """ + + def forward_backward_step(self, + engine: engine.Engine, + data_iter: Iterable, + forward_only=False, + return_loss=True, + return_output_label=True) -> Tuple[torch.Tensor]: + """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. + Returns a tuple with losses if the last stage, an empty tuple otherwise. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). + forward_only (bool, optional): + Whether run forward step only. Default is false. If true, no backward will be run. + return_loss (bool, optional): Whether returns the loss value. Default is true. + return_output_label (bool, optional): If False, the output and label won't be returned. + + Returns: + Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. + """ + + assert forward_only or return_loss, \ + 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' + self.load_batch(data_iter) + + # num_warmup_microbatches is the step when not all the processers are working + num_warmup_microbatches = \ + (gpc.get_world_size(ParallelMode.PIPELINE) + - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) + num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) + num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches + + # Input, output tensors only need to be saved when doing backward passes + input_objs = None + output_objs = None + # local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + if not forward_only: + input_objs = [] + output_objs = [] + return_tensors = [] + if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None + + # Run warmup forward passes. + for i in range(num_warmup_microbatches): + input_obj = comm.recv_forward() + + output_obj = self._forward_step(engine, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss) + + comm.send_forward(output_obj) + + if not forward_only: + input_objs.append(input_obj) + output_objs.append(output_obj) + + # Before running 1F1B, need to receive first forward tensor. + # If all microbatches are run in warmup / cooldown phase, then no need to + # receive this tensor here. + if num_microbatches_remaining > 0: + input_obj = comm.recv_forward() + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + last_iteration = (i == (num_microbatches_remaining - 1)) + + output_obj = self._forward_step(engine, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss) + if forward_only: + comm.send_forward(output_obj) + + if not last_iteration: + input_obj = comm.recv_forward() + + else: + # TODO adjust here + comm.send_forward(output_obj) + output_obj_grad = comm.recv_backward() + + # Add input_obj and output_obj to end of list. + input_objs.append(input_obj) + output_objs.append(output_obj) + + # Pop output_obj and output_obj from the start of the list for + # the backward pass. + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) + + if last_iteration: + input_obj = None + comm.send_backward(input_obj_grad) + else: + input_obj = comm.recv_forward() + comm.send_backward(input_obj_grad) + + # Run cooldown backward passes. + if not forward_only: + for i in range(num_warmup_microbatches): + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + + output_obj_grad = comm.recv_backward() + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) + comm.send_backward(input_obj_grad) + + if len(return_tensors) > 0: + output, label = pack_return_tensors(return_tensors) + return output, label, accum_loss + else: + return None, None, accum_loss diff --git a/colossalai/fx/__init__.py b/colossalai/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d39fa579943d7a4f3cccd7ea7f6e2a20613373c8 --- /dev/null +++ b/colossalai/fx/__init__.py @@ -0,0 +1,4 @@ +from ._compatibility import compatibility, is_compatible_with_meta +from .graph_module import ColoGraphModule +from .passes import MetaInfoProp, metainfo_trace +from .tracer import ColoTracer, meta_trace, symbolic_trace diff --git a/colossalai/fx/_compatibility.py b/colossalai/fx/_compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..0444a481627356e202912c491495364134aa65dd --- /dev/null +++ b/colossalai/fx/_compatibility.py @@ -0,0 +1,54 @@ +from typing import Callable + +import torch + +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) + +if TORCH_MAJOR == 1 and TORCH_MINOR < 12: + META_COMPATIBILITY = False +elif TORCH_MAJOR == 1 and TORCH_MINOR == 12: + from . import _meta_regist_12 + META_COMPATIBILITY = True +elif TORCH_MAJOR == 1 and TORCH_MINOR == 13: + from . import _meta_regist_13 + META_COMPATIBILITY = True +elif TORCH_MAJOR == 2: + META_COMPATIBILITY = True + + +def compatibility(is_backward_compatible: bool = False) -> Callable: + """A decorator to make a function compatible with different versions of PyTorch. + + Args: + is_backward_compatible (bool, optional): Whether the function is backward compatible. Defaults to False. + + Returns: + Callable: The decorated function + """ + + def decorator(func): + if META_COMPATIBILITY: + return func + else: + if is_backward_compatible: + return func + else: + + def wrapper(*args, **kwargs): + raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}') + + return wrapper + + return decorator + + +def is_compatible_with_meta() -> bool: + """Check the meta compatibility. Normally it should be called before importing some of the `colossalai.fx` + modules. If the meta compatibility is not satisfied, the `colossalai.fx` modules will be replaced by its + experimental counterparts. + + Returns: + bool: The meta compatibility + """ + return META_COMPATIBILITY diff --git a/colossalai/fx/_meta_regist_12.py b/colossalai/fx/_meta_regist_12.py new file mode 100644 index 0000000000000000000000000000000000000000..52e8d63ae54355a0d3e27dfc6a3347a2304dc0d7 --- /dev/null +++ b/colossalai/fx/_meta_regist_12.py @@ -0,0 +1,506 @@ +# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py +# should be activated for PyTorch version 1.12.0 and below +# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml +# for more meta_registrations + +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch.utils._pytree import tree_map + +aten = torch.ops.aten + +meta_lib = torch.library.Library("aten", "IMPL", "Meta") + +meta_table = {} + + +def register_meta(op, register_dispatcher=True): + + def wrapper(f): + + def add_func(op): + meta_table[op] = f + if register_dispatcher: + name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__) + try: + meta_lib.impl(name, f) + except: + pass + + tree_map(add_func, op) + return f + + return wrapper + + +# ============================== Convolutions ====================================== +# https://github.com/pytorch/pytorch/pull/79834 +@register_meta(aten.convolution.default) +def meta_conv( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + is_transposed: bool, + output_padding: List[int], + groups: int, +): + + def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: + """ + Formula to apply to calculate the length of some dimension of the output + See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + Args: + ln: length of the dimension + p: padding in that dim + d: dilation in that dim + k: kernel size in that dim + s: stride in that dim + Returns: + The output length + """ + return (ln + 2 * p - d * (k - 1) - 1) // s + 1 + + def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int: + """ + Formula to apply to calculate the length of some dimension of the output + if transposed convolution is used. + See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + Args: + ln: length of the dimension + p: padding in that dim + d: dilation in that dim + k: kernel size in that dim + s: stride in that dim + op: output padding in that dim + Returns: + The output length + """ + return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1 + + def calc_conv_nd_return_shape( + dims: torch.Size, + kernel_size: torch.Size, + stride: Union[List[int], int], + padding: Union[List[int], int], + dilation: Union[List[int], int], + output_padding: Optional[Union[List[int], int]] = None, + ): + ret_shape = [] + if isinstance(stride, int): + stride = [stride] * len(dims) + elif len(stride) == 1: + stride = [stride[0]] * len(dims) + + if isinstance(padding, int): + padding = [padding] * len(dims) + elif len(padding) == 1: + padding = [padding[0]] * len(dims) + + if isinstance(dilation, int): + dilation = [dilation] * len(dims) + elif len(dilation) == 1: + dilation = [dilation[0]] * len(dims) + + output_padding_list: Optional[List[int]] = None + if output_padding: + if isinstance(output_padding, int): + output_padding_list = [output_padding] * len(dims) + elif len(output_padding) == 1: + output_padding_list = [output_padding[0]] * len(dims) + else: + output_padding_list = output_padding + + for i in range(len(dims)): + # If output_padding is present, we are dealing with a transposed convolution + if output_padding_list: + ret_shape.append( + _formula_transposed( + dims[i], + padding[i], + dilation[i], + kernel_size[i], + stride[i], + output_padding_list[i], + )) + else: + ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])) + return ret_shape + + def pick_memory_format(): + if input_tensor.is_contiguous(memory_format=torch.channels_last): + return torch.channels_last + elif input_tensor.is_contiguous(memory_format=torch.contiguous_format): + return torch.contiguous_format + elif input_tensor.is_contiguous(memory_format=torch.preserve_format): + return torch.preserve_format + + kernel_size = weight.shape[2:] + dims = input_tensor.shape[2:] + if is_transposed: + out_channels = groups * weight.shape[1] + + shape_out = calc_conv_nd_return_shape( + dims, + kernel_size, + stride, + padding, + dilation, + output_padding, + ) + + else: + out_channels = weight.shape[0] + if weight.shape[1] != input_tensor.shape[1] / groups: + raise RuntimeError("Invalid channel dimensions") + shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation) + out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) + mem_fmt = pick_memory_format() + out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] + return out + + +@register_meta(aten._convolution.default) +def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int], + padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int, + *extra_args): + out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) + return out + + +@register_meta(aten.convolution_backward.default) +def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, + padding, dilation, transposed, output_padding, groups, output_mask): + return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta') + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp +@register_meta(aten._adaptive_avg_pool2d_backward.default) +def meta_adaptive_avg_pool2d_backward( + grad_output: torch.Tensor, + input: torch.Tensor, +): + grad_input = torch.empty_like(input) + return grad_input + + +# ================================ RNN ============================================= +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp +@register_meta(aten._cudnn_rnn.default) +def meta_cuda_rnn( + input, + weight, + weight_stride0, + weight_buf, + hx, + cx, + mode, + hidden_size, + proj_size, + num_layers, + batch_first, + dropout, + train, + bidirectional, + batch_sizes, + dropout_state, +): + + is_input_packed = len(batch_sizes) != 0 + if is_input_packed: + seq_length = len(batch_sizes) + mini_batch = batch_sizes[0] + batch_sizes_sum = input.shape[0] + else: + seq_length = input.shape[1] if batch_first else input.shape[0] + mini_batch = input.shape[0] if batch_first else input.shape[1] + batch_sizes_sum = -1 + + num_directions = 2 if bidirectional else 1 + out_size = proj_size if proj_size != 0 else hidden_size + if is_input_packed: + out_shape = [batch_sizes_sum, out_size * num_directions] + else: + out_shape = ([mini_batch, seq_length, out_size * + num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions]) + output = input.new_empty(out_shape) + + cell_shape = [num_layers * num_directions, mini_batch, hidden_size] + cy = torch.empty(0) if cx is None else cx.new_empty(cell_shape) + + hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size]) + + # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python) + reserve_shape = 0 if train else 0 + reserve = input.new_empty(reserve_shape, dtype=torch.uint8) + + return output, hy, cy, reserve, weight_buf + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp +@register_meta(aten._cudnn_rnn_backward.default) +def meta_cudnn_rnn_backward(input: torch.Tensor, + weight: torch.Tensor, + weight_stride0: int, + hx: torch.Tensor, + cx: Optional[torch.Tensor] = None, + *args, + **kwargs): + print(input, weight, hx, cx) + grad_input = torch.empty_like(input) + grad_weight = torch.empty_like(weight) + grad_hx = torch.empty_like(hx) + grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta') + return grad_input, grad_weight, grad_hx, grad_cx + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp +# ============================== Activations ======================================= +@register_meta(aten.relu.default) +def meta_relu(input: torch.Tensor): + return torch.empty_like(input) + + +@register_meta(aten.prelu.default) +def meta_prelu(input: torch.Tensor, weight: torch.Tensor): + return torch.empty_like(input) + + +@register_meta(aten.hardswish.default) +def meta_hardswish(input: torch.Tensor): + return torch.empty_like(input) + + +@register_meta(aten.hardtanh.default) +def meta_hardtanh(input: torch.Tensor, min, max): + return torch.empty_like(input) + + +@register_meta(aten.hardswish_backward.default) +def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor): + grad_in = torch.empty_like(input) + return grad_in + + +@register_meta(aten.hardtanh_backward.default) +def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val: int, max_val: int): + grad_in = torch.empty_like(input) + return grad_in + + +# ============================== Normalization ===================================== +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp +@register_meta(aten.native_batch_norm.default) +def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): + n_input = input.size(1) + + output = torch.empty_like(input) + running_mean = torch.empty((n_input), device='meta') + running_var = torch.empty((n_input), device='meta') + return output, running_mean, running_var + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp +@register_meta(aten.native_batch_norm_backward.default) +def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean, + save_invstd, train, eps, output_mask): + dX = torch.empty_like(input) + dgamma = torch.empty_like(weight) + dbeta = torch.empty_like(weight) + return dX, dgamma, dbeta + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp +@register_meta(aten.cudnn_batch_norm.default) +def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): + n_input = input.size(1) + + output = torch.empty_like(input) + running_mean = torch.empty((n_input), device='meta') + running_var = torch.empty((n_input), device='meta') + reserve = torch.empty((0), dtype=torch.uint8, device='meta') + return output, running_mean, running_var, reserve + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp +# NB: CuDNN only implements the backward algorithm for batchnorm +# in training mode (evaluation mode batchnorm has a different algorithm), +# which is why this doesn't accept a 'training' parameter. +@register_meta(aten.cudnn_batch_norm_backward.default) +def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, + save_mean, save_invstd, eps, reserve): + dX = torch.empty_like(input) + dgamma = torch.empty_like(weight) + dbeta = torch.empty_like(weight) + return dX, dgamma, dbeta + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp +@register_meta(aten.native_layer_norm.default) +def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): + bs = input.size(0) + n_input = input.size(1) + + output = torch.empty_like(input) + running_mean = torch.empty((bs, n_input, 1), device='meta') + running_var = torch.empty((bs, n_input, 1), device='meta') + return output, running_mean, running_var + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp +@register_meta(aten.native_layer_norm_backward.default) +def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, + grad_input_mask): + dX = torch.empty_like(input) + dgamma = torch.empty_like(weight) + dbeta = torch.empty_like(bias) + return dX, dgamma, dbeta + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/group_norm.cpp +@register_meta(aten.native_group_norm_backward.default) +def meta_gn_backward(dY: torch.Tensor, input: torch.Tensor, mean, rstd, gamma, N, C, HxW, group, grad_input_mask): + dX = torch.empty_like(input) + dgamma = torch.empty_like(gamma) + dbeta = torch.empty_like(gamma) + return dX, dgamma, dbeta + + +# ================================== Misc ========================================== +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml +@register_meta(aten.roll.default) +def meta_roll(input: torch.Tensor, shifts, dims): + return input + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp +@register_meta(aten._local_scalar_dense.default) +def meta_local_scalar_dense(self: torch.Tensor): + return 0 + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp +@register_meta(aten.where.self) +def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor): + result_type = torch.result_type(self, other) + return torch.empty_like(condition + self + other, dtype=result_type) + + +@register_meta(aten.index.Tensor) +def meta_index_Tensor(self, indices): + assert indices, "at least one index must be provided" + # aten::index is the internal advanced indexing implementation + # checkIndexTensorTypes and expandTensors + result: List[Optional[torch.Tensor]] = [] + for i, index in enumerate(indices): + if index is not None: + assert index.dtype in [torch.long, torch.int8, torch.bool],\ + "tensors used as indices must be long, byte or bool tensors" + if index.dtype in [torch.int8, torch.bool]: + nonzero = index.nonzero() + k = len(result) + assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" + for j in range(index.ndim): + assert index.shape[j] == self.shape[ + k + + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" + result.append(nonzero.select(1, j)) + else: + result.append(index) + else: + result.append(index) + indices = result + assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})" + # expand_outplace + import torch._refs as refs + + indices = list(refs._maybe_broadcast(*indices)) + # add missing null tensors + while len(indices) < self.ndim: + indices.append(None) + + # hasContiguousSubspace + # true if all non-null tensors are adjacent + # See: + # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency + state = 0 + has_contiguous_subspace = False + for index in indices: + if state == 0: + if index is not None: + state = 1 + elif state == 1: + if index is None: + state = 2 + else: + if index is not None: + break + else: + has_contiguous_subspace = True + + # transposeToFront + # This is the logic that causes the newly inserted dimensions to show up + # at the beginning of the tensor, if they're not contiguous + if not has_contiguous_subspace: + dims = [] + transposed_indices = [] + for i, index in enumerate(indices): + if index is not None: + dims.append(i) + transposed_indices.append(index) + for i, index in enumerate(indices): + if index is None: + dims.append(i) + transposed_indices.append(index) + self = self.permute(dims) + indices = transposed_indices + + # AdvancedIndex::AdvancedIndex + # Now we can assume the indices have contiguous subspace + # This is simplified from AdvancedIndex which goes to more effort + # to put the input and indices in a form so that TensorIterator can + # take them. If we write a ref for this, probably that logic should + # get implemented + before_shape: List[int] = [] + after_shape: List[int] = [] + replacement_shape: List[int] = [] + for dim, index in enumerate(indices): + if index is None: + if replacement_shape: + after_shape.append(self.shape[dim]) + else: + before_shape.append(self.shape[dim]) + else: + replacement_shape = list(index.shape) + return self.new_empty(before_shape + replacement_shape + after_shape) + + +# ============================== Embedding ========================================= +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp +@register_meta(aten.embedding_dense_backward.default) +def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, + scale_grad_by_freq): + return torch.empty((num_weights, grad_output.size(-1)), + dtype=grad_output.dtype, + device=grad_output.device, + layout=grad_output.layout) + + +# ============================== Dropout =========================================== +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp +@register_meta(aten.native_dropout.default) +def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False): + # notice that mask is bool + output = torch.empty_like(input) + mask = torch.empty_like(input, dtype=torch.bool) + return output, mask + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp +@register_meta(aten.native_dropout_backward.default) +def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float): + return torch.empty_like(grad) diff --git a/colossalai/fx/_meta_regist_13.py b/colossalai/fx/_meta_regist_13.py new file mode 100644 index 0000000000000000000000000000000000000000..6caa87c449abb7cd332384f622ff484bcc97b68f --- /dev/null +++ b/colossalai/fx/_meta_regist_13.py @@ -0,0 +1,57 @@ +import torch +from torch._meta_registrations import register_meta +from torch._prims_common import check + +aten = torch.ops.aten + + +# since we fix the torch version to 1.13.1, we have to add unimplemented meta ops +# all these functions are from here https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py +@register_meta([aten.convolution_backward.default]) +def meta_convolution_backward( + grad_output_, + input_, + weight_, + bias_sizes_opt, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, +): + # High level logic taken from slow_conv3d_backward_cpu which should + # be representative of all convolution_backward impls + backend_grad_input = None + backend_grad_weight = None + backend_grad_bias = None + + if output_mask[0]: + backend_grad_input = grad_output_.new_empty(input_.size()) + if output_mask[1]: + backend_grad_weight = grad_output_.new_empty(weight_.size()) + if output_mask[2]: + backend_grad_bias = grad_output_.new_empty(bias_sizes_opt) + + return (backend_grad_input, backend_grad_weight, backend_grad_bias) + + +@register_meta(aten._adaptive_avg_pool2d_backward.default) +def meta__adaptive_avg_pool2d_backward(grad_out, self): + ndim = grad_out.ndim + for i in range(1, ndim): + check( + grad_out.size(i) > 0, + lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \ + size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty", + ) + check( + ndim == 3 or ndim == 4, + lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}", + ) + check( + self.dtype == grad_out.dtype, + lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}", + ) + return self.new_empty(self.shape) diff --git a/colossalai/fx/codegen/__init__.py b/colossalai/fx/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..794692f511537db5a8a54588f15ecafb41242ea2 --- /dev/null +++ b/colossalai/fx/codegen/__init__.py @@ -0,0 +1 @@ +from .activation_checkpoint_codegen import * diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..5a72cb9ca923bcdf5f9b3daddad9b5e7c339d5c5 --- /dev/null +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -0,0 +1,1058 @@ +from typing import Any, Callable, Dict, Iterable, List, Tuple + +import torch + +import colossalai + +try: + from torch.fx.graph import ( + CodeGen, + PythonCode, + _custom_builtins, + _CustomBuiltin, + _format_target, + _is_from_torch, + _Namespace, + _origin_type_map, + inplace_methods, + magic_methods, + ) + from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + CODEGEN_AVAILABLE = True +except: + from torch.fx.graph import ( + PythonCode, + _custom_builtins, + _CustomBuiltin, + _format_args, + _format_target, + _is_from_torch, + _Namespace, + _origin_type_map, + magic_methods, + ) + from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + CODEGEN_AVAILABLE = False + +if CODEGEN_AVAILABLE: + __all__ = ['ActivationCheckpointCodeGen'] +else: + __all__ = ['python_code_with_activation_checkpoint'] + + +def _gen_saved_tensors_hooks(): + """ + Generate saved tensors hooks + """ + + pack_hook = """def pack_hook_input(self, x): + if getattr(x, "offload", False): + return (x.device, x.cpu()) + else: + return x + +def pack_hook_no_input(self, x): + if getattr(x, "offload", True): + return (x.device, x.cpu()) + else: + return x +""" + + unpack_hook = """def unpack_hook(self, packed): + if isinstance(packed, tuple): + device, tensor = packed + return tensor.to(device) + else: + return packed +""" + + return pack_hook, unpack_hook + + +def _gen_save_tensors_hooks_context(offload_input=True) -> str: + """Generate customized saved_tensors_hooks + Args: + offload_input (bool, optional): whether we need offload input, if offload_input=False, + we will use self.pack_hook_no_input instead. Defaults to True. + Returns: + str: generated context + """ + + if offload_input: + context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n" + else: + context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n" + return context + + +def _gen_save_on_cpu_context(): + """ + Generate save on cpu context + """ + + context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n" + return context + + +def _find_input_and_output_nodes(nodes: List[Node]): + """ + Find the input and output node names which are not found in the given list of nodes. + """ + input_nodes = [] + output_nodes = [] + + # if a node has an input node which is not in the node list + # we treat that input node as the input of the checkpoint function + for node in nodes: + for input_node in node._input_nodes.keys(): + node_repr = repr(input_node) + if input_node not in nodes and node_repr not in input_nodes: + input_nodes.append(node_repr) + + # if a node has a user node which is not in the node list + # we treat that user node as the node receiving the current node output + for node in nodes: + for output_node in node.users.keys(): + node_repr = repr(node) + if output_node not in nodes and node_repr not in output_nodes: + output_nodes.append(node_repr) + + return input_nodes, output_nodes + + +def _find_ckpt_regions(nodes: List[Node]): + """ + Find the checkpoint regions given a list of consecutive nodes. The outputs will be list + of tuples, each tuple is in the form of (start_index, end_index). + """ + ckpt_nodes = [] + ckpt_regions = [] + start = -1 + end = -1 + current_region = None + + for idx, node in enumerate(nodes): + if 'activation_checkpoint' in node.meta: + act_ckpt_label = node.meta['activation_checkpoint'] + + # this activation checkpoint label is not set yet + # meaning this is the first node of the activation ckpt region + if current_region is None: + current_region = act_ckpt_label + start = idx + + # if activation checkpoint has changed + # we restart the tracking + # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2] + if act_ckpt_label != current_region: + assert start != -1 + ckpt_regions.append((start, idx - 1)) + current_region = act_ckpt_label + start = idx + end = -1 + elif current_region is not None and not 'activation_checkpoint' in node.meta: + # used to check the case below + # node ckpt states = [ckpt, ckpt, non-ckpt] + end = idx - 1 + assert start != -1 and end != -1 + ckpt_regions.append((start, end)) + start = end = -1 + current_region = None + else: + pass + return ckpt_regions + + +def _find_offload_regions(nodes: List[Node]): + """This function is to find the offload regions + In pofo algorithm, during annotation, we will annotate the offload region with the + list in the form of [idx, offload_input, offload_bar]. idx indicates the offload + region's index, offload_input is a bool type indicates whether we need to offload + the input, offload_bar is a bool type indicates whether we need to offload all the + intermediate x_bars of this region. + """ + offload_regions = [] + offload_labels = [] + start = -1 + end = -1 + current_region = None + + for idx, node in enumerate(nodes): + if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable): + act_offload_label = node.meta['activation_offload'] + + if current_region == None: + current_region = act_offload_label + start = idx + offload_labels.append(act_offload_label) + + if act_offload_label != current_region: + assert start != -1 + offload_regions.append((start, idx - 1)) + offload_labels.append(act_offload_label) + current_region = act_offload_label + start = idx + end = -1 + + else: + if current_region is not None: + end = idx - 1 + assert start != -1 and end != -1 + offload_regions.append((start, end)) + start = end = -1 + current_region = None + + else: + pass + + return offload_regions, offload_labels + + +def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str: + """ + Generate the checkpoint function definition + """ + return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):" + + +def _gen_ckpt_output(output_vars: List[str]) -> str: + """ + Generate the return statement for checkpoint region + """ + return f"return {', '.join(output_vars)}" + + +def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant=True): + """ + Generate the checkpoint function call code text + """ + outputs = ', '.join(output_vars) + inputs = ', '.join(input_vars) + return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})' + + +def _end_of_ckpt(node: Node, check_idx: int) -> bool: + """Check if the node could end the ckpt region + Args: + node (Node): torch.fx.Node + check_idx (int): the index of checkpoint level for + nested checkpoint + Returns: + bool + """ + if 'activation_checkpoint' in node.meta: + if isinstance(node.meta['activation_checkpoint'], list): + return node.meta['activation_checkpoint'][check_idx] == None + else: + return False + else: + return True + + +def _find_nested_ckpt_regions(nodes, check_idx=0): + """ + Find the nested checkpoint regions given a list of consecutive nodes. The outputs + will be list of tuples, each tuple is in the form of (start_index, end_index). + """ + ckpt_regions = [] + start = -1 + end = -1 + current_region = None + + for idx, node in enumerate(nodes): + if 'activation_checkpoint' in node.meta: + if isinstance(node.meta['activation_checkpoint'], int): + act_ckpt_label = node.meta['activation_checkpoint'] + else: + act_ckpt_label = node.meta['activation_checkpoint'][check_idx] + + # this activation checkpoint label is not set yet + # meaning this is the first node of the activation ckpt region + if current_region is None: + current_region = act_ckpt_label + start = idx + + # if activation checkpoint has changed + # we restart the tracking + # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2] + if act_ckpt_label != current_region: + assert start != -1 + ckpt_regions.append((start, idx - 1)) + current_region = act_ckpt_label + start = idx + end = -1 + elif current_region is not None and _end_of_ckpt(node, check_idx): + # used to check the case below + # node ckpt states = [ckpt, ckpt, non-ckpt] + end = idx - 1 + assert start != -1 and end != -1 + ckpt_regions.append((start, end)) + start = end = -1 + current_region = None + else: + pass + + if current_region is not None: + end = len(nodes) - 1 + ckpt_regions.append((start, end)) + return ckpt_regions + + +def emit_ckpt_func(body, + ckpt_func, + node_list: List[Node], + emit_node_func, + delete_unused_value_func, + level=0, + in_ckpt=False): + """Emit ckpt function in nested way + Args: + body: forward code, in recursive calls, this part will be checkpoint + functions code + ckpt_func: checkpoint functions code, in recursive calls, this part + will be a buffer + node_list (List[Node]): list of torch.fx.Node + emit_node_func: function to emit a node + delete_unused_value_func: function to delete unused value + level (int, optional): checkpoint level. Defaults to 0. + in_ckpt (bool, optional): indicates wether the func is in recursive + call. Defaults to False. + """ + inputs, outputs = _find_input_and_output_nodes(node_list) + + # if the current checkpoint function use int as label, using old generation method + if isinstance(node_list[0].meta['activation_checkpoint'], int): + label = node_list[0].meta['activation_checkpoint'] + ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) + ckpt_func.append(f'{ckpt_fn_def}\n') + for node in node_list: + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + + ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + activation_offload = node_list[0].meta.get('activation_offload', False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + usage += "\n" + body.append(usage) + + # use nested ckpt function codegen + else: + # label given by each layer, e.g. if you are currently at level [0, 1, 1] + # the label will be '0_1_1' + label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]]) + ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) + ckpt_func.append(f'{ckpt_fn_def}\n') + + # if there is more level to fetch + if level + 1 < len(node_list[0].meta['activation_checkpoint']): + ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1) + start_idx = [item[0] for item in ckpt_regions] + end_idx = [item[1] for item in ckpt_regions] + + # use ckpt_func_buffer to store nested checkpoint functions + ckpt_func_buffer = [] + node_idx = 0 + while 1: + if node_idx >= len(node_list): + break + + if node_idx in start_idx: + ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] + emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, + delete_unused_value_func, level + 1, True) + node_idx += len(ckpt_node_list) + + else: + node = node_list[node_idx] + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + node_idx += 1 + + ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + ckpt_func += ckpt_func_buffer + activation_offload = node_list[0].meta.get('activation_offload', False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' + if in_ckpt: + usage = ' ' + usage + body.append(usage) + + # last level + else: + for node in node_list: + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + + ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + activation_offload = node_list[0].meta.get('activation_offload', False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' + if in_ckpt: + usage = ' ' + usage + body.append(usage) + + +def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): + """Emit code with nested activation checkpoint + When we detect some of the node.activation_checkpoint is a List, we will use + this function to emit the activation checkpoint codes. + Args: + body: forward code + ckpt_func: checkpoint functions code + nodes: graph.nodes + emit_node_func: function to emit node + delete_unused_value_func: function to remove the unused value + """ + ckpt_regions = _find_nested_ckpt_regions(nodes, 0) + start_idx = [item[0] for item in ckpt_regions] + end_idx = [item[1] for item in ckpt_regions] + + # find the offload regions + offload_regions, offload_labels = _find_offload_regions(nodes) + offload_starts = [item[0] for item in offload_regions] + offload_ends = [item[1] for item in offload_regions] + offload_inputs = [] + offload_outputs = [] + within_offload_region = False + + node_list = list(nodes) + + # find the input and output var names for each offload region + for idx, (start, end) in enumerate(offload_regions): + offload_node_list = node_list[start:end + 1] + inputs, outputs = _find_input_and_output_nodes(offload_node_list) + offload_inputs.append(inputs) + offload_outputs.append(outputs) + + # this flag is to prevent repeated insert of save tensors + # hooks definition in ckpt_func + is_hook_inserted = False + node_idx = 0 + while 1: + # break if we finish the processing all the nodes + if node_idx >= len(node_list): + break + + # process ckpt_regions + if node_idx in start_idx: + ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] + emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func) + node_idx += len(ckpt_node_list) + + # process node in forward function + else: + node = node_list[node_idx] + + if node_idx in offload_starts: + offload_label = offload_labels[offload_starts.index(node_idx)] + _, offload_input, offload_bar = offload_label + within_offload_region = True + + # insert hook functions if needed + if not is_hook_inserted: + pack_hook, unpack_hook = _gen_saved_tensors_hooks() + ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n") + is_hook_inserted = True + + if offload_input and offload_bar: + body.append(_gen_save_on_cpu_context()) + + elif offload_input: + for par in offload_inputs[offload_label[0]]: + body.append(f"setattr({par}, 'offload', True)\n") + body.append(_gen_save_tensors_hooks_context(offload_input=True)) + + else: + for par in offload_inputs[offload_label[0]]: + body.append(f"setattr({par}, 'offload', False)\n") + body.append(_gen_save_tensors_hooks_context(offload_input=False)) + + if within_offload_region: + emit_node_func(node, body) + body[-1] = ' ' + body[-1] + delete_unused_value_func(node, body) + + else: + emit_node_func(node, body) + delete_unused_value_func(node, body) + + if node_idx in offload_ends: + within_offload_region = False + + node_idx += 1 + + +def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): + # find the activation checkpoint regions + ckpt_regions = _find_ckpt_regions(nodes) + start_idx = [item[0] for item in ckpt_regions] + end_idx = [item[1] for item in ckpt_regions] + input_vars = [] + output_vars = [] + within_ckpt_region = False + + # find the offload regions + offload_regions, offload_labels = _find_offload_regions(nodes) + offload_starts = [item[0] for item in offload_regions] + offload_ends = [item[1] for item in offload_regions] + offload_inputs = [] + offload_outputs = [] + within_offload_region = False + + node_list = list(nodes) + + # use this variable to avoid inserting hook functions + # to ckpt_func repeatedly + is_hook_inserted = False + + # find the input and output var names for each region + for idx, (start, end) in enumerate(ckpt_regions): + ckpt_node_list = node_list[start:end + 1] + inputs, outputs = _find_input_and_output_nodes(ckpt_node_list) + input_vars.append(inputs) + output_vars.append(outputs) + + # find the input and output var names for each offload region + for idx, (start, end) in enumerate(offload_regions): + offload_node_list = node_list[start:end + 1] + inputs, outputs = _find_input_and_output_nodes(offload_node_list) + offload_inputs.append(inputs) + offload_outputs.append(outputs) + + # append code text to body + for idx, node in enumerate(node_list): + # if this is the first node of the ckpt region + # append the ckpt function defition + if idx in start_idx: + label = start_idx.index(idx) + ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label]) + ckpt_func.append(f'{ckpt_fn_def}\n') + within_ckpt_region = True + + if idx in offload_starts: + offload_label = offload_labels[offload_starts.index(idx)] + _, offload_input, offload_bar = offload_label + within_offload_region = True + + # insert hook functions if needed + if not is_hook_inserted: + pack_hook, unpack_hook = _gen_saved_tensors_hooks() + ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n") + is_hook_inserted = True + + if offload_input and offload_bar: + body.append(_gen_save_on_cpu_context()) + + elif offload_input: + for par in offload_inputs[offload_label[0]]: + body.append(f"setattr({par}, 'offload', True)\n") + body.append(_gen_save_tensors_hooks_context(offload_input=True)) + + else: + for par in offload_inputs[offload_label[0]]: + body.append(f"setattr({par}, 'offload', False)\n") + body.append(_gen_save_tensors_hooks_context(offload_input=False)) + + # NOTE: emit_node does not emit a string with newline. It depends + # on delete_unused_values to append one + # NOTE: currently we separate body and ckpt_func definition + if within_ckpt_region: + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + + elif within_offload_region: + emit_node_func(node, body) + body[-1] = ' ' + body[-1] + delete_unused_value_func(node, body) + + else: + emit_node_func(node, body) + delete_unused_value_func(node, body) + + if idx in end_idx: + # if this is the last node of the ckpt region + # generate return statement + label = end_idx.index(idx) + return_statement = _gen_ckpt_output(output_vars[label]) + return_statement = f' {return_statement}\n\n' + ckpt_func.append(return_statement) + + # we need to check if the checkpoint need to offload the input + start_node_idx = start_idx[label] + if 'activation_offload' in node_list[start_node_idx].meta: + activation_offload = node_list[start_node_idx].meta['activation_offload'] + else: + activation_offload = False + + # we need to check if the checkpoint need use_reentrant=False + use_reentrant = True + non_leaf_input = 0 + for var in input_vars[label]: + input_node = next(item for item in node_list if item.name == var) + if input_node.op != "placeholder": + non_leaf_input = 1 + for user in input_node.users: + if 'activation_checkpoint' in user.meta: + if user.meta['activation_checkpoint'] == label: + if user.op == "call_module": + if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"): + use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace + + elif user.op == "call_function": + if "inplace" in user.kwargs: + use_reentrant = not user.kwargs["inplace"] + + # if all the inputs are leaf nodes, we need to set use_reentrant = False + if not non_leaf_input: + use_reentrant = False + + # generate checkpoint function call in a new line + usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant) + usage += '\n' + body.append(usage) + within_ckpt_region = False + + if idx in offload_ends: + within_offload_region = False + + +if CODEGEN_AVAILABLE: + + class ActivationCheckpointCodeGen(CodeGen): + + def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: + free_vars: List[str] = [] + body: List[str] = [] + globals_: Dict[str, Any] = {} + wrapped_fns: Dict[str, None] = {} + + # Wrap string in list to pass by reference + maybe_return_annotation: List[str] = [''] + + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. + We call this for names that reference objects external to the + Graph, like functions or types. + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] is obj + return global_name + globals_[global_name] = obj + return global_name + + # set _custom_builtins here so that we needn't import colossalai in forward + _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) + + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) + + def type_repr(o: Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return '()' + + typename = _type_repr(o) + + if hasattr(o, '__origin__'): + # This is a generic type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + origin_typename = add_global(_type_repr(origin_type), origin_type) + + if hasattr(o, '__args__'): + # Assign global names for each of the inner type variables. + args = [type_repr(arg) for arg in o.__args__] + + if len(args) == 0: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python < 3.9 + return origin_typename + + return f'{origin_typename}[{",".join(args)}]' + else: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python 3.9+ + return origin_typename + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: + + def _get_repr(arg): + # Handle NamedTuples (if it has `_fields`) via add_global. + if isinstance(arg, tuple) and hasattr(arg, '_fields'): + qualified_name = _get_qualified_name(type(arg)) + global_name = add_global(qualified_name, type(arg)) + return f"{global_name}{repr(tuple(arg))}" + return repr(arg) + + args_s = ', '.join(_get_repr(a) for a in args) + kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + if args_s and kwargs_s: + return f'{args_s}, {kwargs_s}' + return args_s or kwargs_s + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + # NOTE: we add a variable to distinguish body and ckpt_func + def delete_unused_values(user: Node, body): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == 'placeholder': + return + if user.op == 'output': + body.append('\n') + return + nodes_to_delete = user_to_last_uses.get(user, []) + if len(nodes_to_delete): + to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) + body.append(f'; {to_delete_str}\n') + else: + body.append('\n') + + # NOTE: we add a variable to distinguish body and ckpt_func + def emit_node(node: Node, body): + maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' + if node.op == 'placeholder': + assert isinstance(node.target, str) + maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' + free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') + raw_name = node.target.replace('*', '') + if raw_name != repr(node): + body.append(f'{repr(node)} = {raw_name}\n') + return + elif node.op == 'call_method': + assert isinstance(node.target, str) + body.append( + f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' + f'({_format_args(node.args[1:], node.kwargs)})') + return + elif node.op == 'call_function': + assert callable(node.target) + # pretty print operators + if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + assert isinstance(node.args, tuple) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + return + + # pretty print inplace operators; required for jit.script to work properly + # not currently supported in normal FX graphs, but generated by torchdynamo + if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods: + body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; ' + f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}') + return + + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if global_name == 'getattr' and \ + isinstance(node.args, tuple) and \ + isinstance(node.args[1], str) and \ + node.args[1].isidentifier() and \ + len(node.args) == 2: + body.append( + f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + return + body.append( + f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') + if node.meta.get('is_wrapped', False): + wrapped_fns.setdefault(global_name) + return + elif node.op == 'call_module': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + return + elif node.op == 'get_attr': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + return + elif node.op == 'output': + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append(self.generate_output(node.args[0])) + return + raise NotImplementedError(f'node: {node.op} {node.target}') + + # Modified for activation checkpointing + ckpt_func = [] + + # if any node has a list of labels for activation_checkpoint, we + # will use nested type of activation checkpoint codegen + if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes): + emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) + else: + emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append('pass\n') + + if len(wrapped_fns) > 0: + wrap_name = add_global('wrap', torch.fx.wrap) + wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + else: + wrap_stmts = '' + + if self._body_transformer: + body = self._body_transformer(body) + + for name, value in self.additional_globals(): + add_global(name, value) + + # as we need colossalai.utils.checkpoint, we need to import colossalai + # in forward function + prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) + prologue = ''.join(ckpt_func) + prologue + prologue = prologue + + code = ''.join(body) + code = '\n'.join(' ' + line for line in code.split('\n')) + fn_code = f""" +{wrap_stmts} +{prologue} +{code}""" + return PythonCode(fn_code, globals_) + +else: + + def python_code_with_activation_checkpoint(self, root_module: str, namespace: _Namespace) -> PythonCode: + """ + This method is copied from the _python_code of torch.fx.graph.Graph. Modifications are made so that it can generate + code for activation checkpoint. + """ + free_vars: List[str] = [] + body: List[str] = [] + globals_: Dict[str, Any] = {} + wrapped_fns: Dict[str, None] = {} + + # Wrap string in list to pass by reference + maybe_return_annotation: List[str] = [''] + + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. + We call this for names that reference objects external to the + Graph, like functions or types. + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] is obj + return global_name + globals_[global_name] = obj + return global_name + + # set _custom_builtins here so that we needn't import colossalai in forward + _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) + + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) + + def type_repr(o: Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return '()' + + typename = _type_repr(o) + + # This is a generic type, e.g. typing.List[torch.Tensor] + if hasattr(o, '__origin__'): + origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + origin_typename = add_global(_type_repr(origin_type), origin_type) + + # Assign global names for each of the inner type variables. + args = [type_repr(arg) for arg in o.__args__] + + return f'{origin_typename}[{",".join(args)}]' + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(self.nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + # NOTE: we add a variable to distinguish body and ckpt_func + def delete_unused_values(user: Node, body): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == 'placeholder': + return + if user.op == 'output': + body.append('\n') + return + nodes_to_delete = user_to_last_uses.get(user, []) + if len(nodes_to_delete): + to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) + body.append(f'; {to_delete_str}\n') + else: + body.append('\n') + + # NOTE: we add a variable to distinguish body and ckpt_func + def emit_node(node: Node, body): + maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' + if node.op == 'placeholder': + assert isinstance(node.target, str) + maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' + free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') + raw_name = node.target.replace('*', '') + if raw_name != repr(node): + body.append(f'{repr(node)} = {raw_name}\n') + return + elif node.op == 'call_method': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' + f'({_format_args(node.args[1:], node.kwargs)})') + return + elif node.op == 'call_function': + assert callable(node.target) + # pretty print operators + if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + assert isinstance(node.args, tuple) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + return + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if global_name == 'getattr' and \ + isinstance(node.args, tuple) and \ + isinstance(node.args[1], str) and \ + node.args[1].isidentifier() and \ + len(node.args) == 2: + body.append( + f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + return + body.append( + f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') + if node.meta.get('is_wrapped', False): + wrapped_fns.setdefault(global_name) + return + elif node.op == 'call_module': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + return + elif node.op == 'get_attr': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + return + elif node.op == 'output': + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + if self._pytree_info is None: + body.append(f'return {repr(node.args[0])}') + else: + body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)') + return + raise NotImplementedError(f'node: {node.op} {node.target}') + + # Modified for activation checkpointing + ckpt_func = [] + + # if any node has a list of labels for activation_checkpoint, we + # will use nested type of activation checkpoint codegen + if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes): + emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) + else: + emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append('pass\n') + if self._pytree_info is not None: + orig_args = self._pytree_info.orig_args + has_orig_self = (orig_args[0] == 'self') + if has_orig_self: + free_vars.insert(0, 'self') + if len(free_vars) > 0: # pytree has placeholders in it + body.insert( + 0, + f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n") + else: + orig_args = free_vars + + if len(wrapped_fns) > 0: + wrap_name = add_global('wrap', torch.fx.wrap) + wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + else: + wrap_stmts = '' + + ckpt_func = ''.join(ckpt_func) + + # If the original function didn't have self as its first argument, we + # would have added it. + if len(orig_args) == 0 or orig_args[0] != 'self': + orig_args.insert(0, 'self') + code = ''.join(body) + code = '\n'.join(' ' + line for line in code.split('\n')) + + # as we need colossalai.utils.checkpoint, we need to import colossalai + # in forward function + fn_code = f""" +{wrap_stmts} +{ckpt_func} +def forward({', '.join(orig_args)}){maybe_return_annotation[0]}: +{code}""" + return PythonCode(fn_code, globals_) diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..ebb9975f27dbf312870fdb9d7beea98faf7889b7 --- /dev/null +++ b/colossalai/fx/graph_module.py @@ -0,0 +1,174 @@ +import os +import warnings +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Type, Union + +import torch +import torch.nn as nn +from torch.nn.modules.module import _addindent + +try: + from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen + from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall + + from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen + COLOGM = True +except: + from torch.fx.graph import Graph + from torch.fx.graph_module import GraphModule + COLOGM = False + +if COLOGM: + + class ColoGraphModule(GraphModule): + + def __init__(self, + root: Union[torch.nn.Module, Dict[str, Any]], + graph: Graph, + class_name: str = 'GraphModule', + ckpt_codegen: bool = True): + if ckpt_codegen: + graph.set_codegen(ActivationCheckpointCodeGen()) + super().__init__(root, graph, class_name) + + def bind(self, ckpt_def, globals): + """Bind function needed for correctly execute gm forward + + We need to bind checkpoint functions and saved_tensor_hooks functions + to gm so that we could correctly execute gm forward + + Args: + ckpt_def (_type_): definition before the forward function + globals (_type_): global variables + """ + + ckpt_code = "\n".join(ckpt_def) + globals_copy = globals.copy() + _exec_with_source(ckpt_code, globals_copy) + func_list = [func for func in globals_copy.keys() if "checkpoint" in func or "pack" in func] + for func in func_list: + tmp_func = globals_copy[func] + setattr(self, func, tmp_func.__get__(self, self.__class__)) + del globals_copy[func] + + def recompile(self) -> PythonCode: + """ + Recompile this GraphModule from its ``graph`` attribute. This should be + called after editing the contained ``graph``, otherwise the generated + code of this ``GraphModule`` will be out of date. + """ + if isinstance(self._graph._codegen, _PyTreeCodeGen): + self._in_spec = self._graph._codegen.pytree_info.in_spec + self._out_spec = self._graph._codegen.pytree_info.out_spec + python_code = self._graph.python_code(root_module='self') + self._code = python_code.src + + # To split ckpt functions code and forward code + _code_list = self._code.split("\n") + _fwd_def = [item for item in _code_list if "def forward" in item][0] + _fwd_idx = _code_list.index(_fwd_def) + ckpt_def = _code_list[:_fwd_idx] + self._code = "\n".join(_code_list[_fwd_idx:]) + + self.bind(ckpt_def, python_code.globals) + + cls = type(self) + cls.forward = _forward_from_src(self._code, python_code.globals) + + # Determine whether this class explicitly defines a __call__ implementation + # to wrap. If it does, save it in order to have wrapped_call invoke it. + # If it does not, wrapped_call can use a dynamic call to super() instead. + # In most cases, super().__call__ should be torch.nn.Module.__call__. + # We do not want to hold a reference to Module.__call__ here; doing so will + # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. + cls_call = cls.__call__ if "__call__" in vars(cls) else None + + if '_wrapped_call' not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + + def call_wrapped(self, *args, **kwargs): + return self._wrapped_call(self, *args, **kwargs) + + cls.__call__ = call_wrapped + + # reset self._code to original src, otherwise to_folder will be wrong + self._code = python_code.src + return python_code + + def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"): + """Dumps out module to ``folder`` with ``module_name`` so that it can be + imported with ``from import `` + + Args: + + folder (Union[str, os.PathLike]): The folder to write the code out to + + module_name (str): Top-level name to use for the ``Module`` while + writing out the code + """ + folder = Path(folder) + Path(folder).mkdir(exist_ok=True) + torch.save(self.state_dict(), folder / 'state_dict.pt') + tab = " " * 4 + + # we add import colossalai here + model_str = f""" +import torch +from torch.nn import * +import colossalai + + +class {module_name}(torch.nn.Module): + def __init__(self): + super().__init__() +""" + + def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: + safe_reprs = [ + nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + ] + if type(module) in safe_reprs: + return f"{module.__repr__()}" + else: + return None + + blobified_modules = [] + for module_name, module in self.named_children(): + module_str = _gen_model_repr(module_name, module) + if module_str is None: + module_file = folder / f'{module_name}.pt' + torch.save(module, module_file) + blobified_modules.append(module_name) + module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') + module_str = f"torch.load(r'{module_file}') # {module_repr}" + model_str += f"{tab*2}self.{module_name} = {module_str}\n" + + for buffer_name, buffer in self._buffers.items(): + if buffer is None: + continue + model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" + + for param_name, param in self._parameters.items(): + if param is None: + continue + model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" + + model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" + model_str += f"{_addindent(self.code, 4)}\n" + + module_file = folder / 'module.py' + module_file.write_text(model_str) + + init_file = folder / '__init__.py' + init_file.write_text('from .module import *') + + if len(blobified_modules) > 0: + warnings.warn("Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}") + +else: + + class ColoGraphModule(GraphModule): + + def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'): + super().__init__(root, graph, class_name) diff --git a/colossalai/fx/passes/__init__.py b/colossalai/fx/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6f948cb2d3b37ce73e03044a944732a7d8370d6b --- /dev/null +++ b/colossalai/fx/passes/__init__.py @@ -0,0 +1,4 @@ +from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass +from .concrete_info_prop import ConcreteInfoProp +from .meta_info_prop import MetaInfoProp, metainfo_trace +from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7b842b530cc12fb669154456dc51e489ccc85d --- /dev/null +++ b/colossalai/fx/passes/adding_split_node_pass.py @@ -0,0 +1,363 @@ +import numpy as np +import torch +import tqdm +from torch.fx import symbolic_trace +from torch.fx.node import Node + +from colossalai.fx.passes.split_module import split_module + + +def pipe_split(): + pass + + +def block_split(): + pass + + +# Construct blocks with the condition that (block_flops / total_flops) >= limit. +def construct_blocks(gm: torch.fx.GraphModule, limit=0.01): + total_fwd_flop = 0 + total_bwd_flop = 0 + for node in gm.graph.nodes: + total_fwd_flop += node.fwd_flop + total_bwd_flop += node.bwd_flop + + total_flop = total_fwd_flop + total_bwd_flop + per_block_flop = total_flop * limit + accumulate_fwd_flop = 0 + accumulate_bwd_flop = 0 + block_nodes = [] + for node in gm.graph.nodes: + if 'block_split' in node.name: + continue + accumulate_fwd_flop += node.fwd_flop + accumulate_bwd_flop += node.bwd_flop + if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop: + with gm.graph.inserting_after(node): + block_node = gm.graph.create_node('call_function', block_split) + setattr(block_node, 'fwd_flop', accumulate_fwd_flop) + setattr(block_node, 'bwd_flop', accumulate_bwd_flop) + accumulate_fwd_flop = 0 + accumulate_bwd_flop = 0 + block_nodes.append(block_node) + + return block_nodes + + +def remove_blocks(gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + if (node.op, node.target) == ('call_function', block_split): + gm.graph.erase_node(node) + + +def get_compute_costs(node_list): + num_nodes = len(node_list) + all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64) + + for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0): + for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False): + selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)] + all_compute_cost[start, end] = sum(selected_flops) + + return all_compute_cost + + +def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_costs, max_compute_cost): + """The core implementation of the DP algorithm.""" + # Adapted from Alpa DP Formulation. + # For f, node ID start from 0 + # f[number of stages, + # node id that is currently being considered] + + # record time cost(assess by fwd+bwd flop now) + f = np.full((num_stages + 1, num_nodes + 1), np.inf, dtype=np.float32) + + # record max stage compute cost among all stages in this partition. + f_stage_max = np.full((num_stages + 1, num_nodes + 1), 0.0, dtype=np.float32) + # record start node index for next stage in this partition + f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32) + f[0, num_nodes] = 0 + for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks + for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False): + for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False): + stage_cost = compute_costs[i, k - 1] + new_cost = f[s - 1, k] + stage_cost + if (stage_cost <= max_compute_cost and new_cost < f[s, i]): + f[s, i] = new_cost + f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost) + f_argmin[s, i] = k + + best_total_cost = f[num_stages, 0] + if np.isinf(best_total_cost): + return np.inf, None + + total_cost = f[num_stages, 0] + (num_microbatches - 1) * f_stage_max[num_stages, 0] + + current_s = num_stages + current_node = 0 + + res = [] + while current_s > 0 and current_node < num_nodes: + next_start_node = f_argmin[current_s, current_node] + res.append((current_node, next_start_node)) + current_s -= 1 + current_node = next_start_node + + return total_cost, res + + +def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatches: int): + # Ignore the memory cost profiling in Alpa's design for convenience. + max_compute_costs = np.sort(np.unique(compute_costs)) + best_cost = np.inf + best_solution = None + last_max_compute_cost = 0.0 + gap = 1e6 # temporary magic number, unit: flops + + for max_compute_cost in tqdm.tqdm(max_compute_costs): + # Pruning to reduce search space. + if max_compute_cost * num_microbatches >= best_cost: + break + if max_compute_cost - last_max_compute_cost < gap: + continue + + cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs, + max_compute_cost) + + if cost < best_cost: + best_cost = cost + best_solution = solution + last_max_compute_cost = max_compute_cost + return best_cost, best_solution + + +# Auto DP partition based on Alpa. +# Adapted to Gpipe Scheduler +# split_mode: +# 'node': fx_node +# 'block': many fx_nodes construct a block +def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01): + assert mode in ['node', 'block'] + + # nodes or blocks will be used in partition. + node_list = [] + if mode == 'node': + for node in gm.graph.nodes: + node_list.append(node) + elif mode == 'block': + node_list = construct_blocks(gm, limit=block_limit) + else: + pass + + compute_costs = get_compute_costs(node_list) + + best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches) + + for (_, next_start_node) in best_solution: + if pp_size <= 1: + break + node = node_list[next_start_node] + with gm.graph.inserting_before(node): + split_node = gm.graph.create_node('call_function', pipe_split) + pp_size -= 1 + + # remove block node if possible + if mode == 'block': + remove_blocks(gm) + + gm.recompile() + return gm + + +def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int): + """ + In avgcompute_split_pass, we split module by the fwd flops. + """ + mod_graph = gm.graph + # To use avgcompute_split_pass, we need run meta_info_prop interpreter first. + # If nodes don't have meta info, this pass will fall back to normal balanced split pass. + check_node = list(mod_graph.nodes)[0] + if 'tensor_meta' not in check_node.meta: + return balanced_split_pass(gm, pp_size) + + total_fwd_flop = 0 + for node in mod_graph.nodes: + total_fwd_flop += node.fwd_flop + + partition_flop = total_fwd_flop // pp_size + accumulate_fwd_flop = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + if 'pipe_split' in node.name: + continue + accumulate_fwd_flop += node.fwd_flop + if accumulate_fwd_flop >= partition_flop: + total_fwd_flop = total_fwd_flop - accumulate_fwd_flop + accumulate_fwd_flop = 0 + pp_size -= 1 + partition_flop = total_fwd_flop // pp_size + with mod_graph.inserting_after(node): + split_node = mod_graph.create_node('call_function', pipe_split) + gm.recompile() + return gm + + +def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int): + """ + In avgnode_split_pass, simpliy split graph by node number. + """ + mod_graph = gm.graph + avg_num_node = len(mod_graph.nodes) // pp_size + accumulate_num_node = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + accumulate_num_node += 1 + if accumulate_num_node >= avg_num_node: + accumulate_num_node = 0 + pp_size -= 1 + if node.next.op == 'output': + with mod_graph.inserting_before(node): + split_node = mod_graph.create_node('call_function', pipe_split) + else: + with mod_graph.inserting_after(node): + split_node = mod_graph.create_node('call_function', pipe_split) + gm.recompile() + return gm + + +def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): + """ + In balanced_split_pass, we split module by the size of parameters(weights+bias). + """ + mod_graph = gm.graph + total_param_amount = 0 + for param in mod_graph.owning_module.parameters(): + total_param_amount += param.numel() + params_per_partition = total_param_amount // pp_size + accumulate_param_amount = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + if node.op == "call_module": + target_module = node.graph.owning_module.get_submodule(node.target) + for param in target_module.parameters(): + accumulate_param_amount += param.numel() + if accumulate_param_amount >= params_per_partition: + accumulate_param_amount = 0 + pp_size -= 1 + # If the next node is output node, we will insert split annotation before + # node to make sure there is at least one node in last partition. + if node.next.op == 'output': + with mod_graph.inserting_before(node): + split_node = mod_graph.create_node('call_function', pipe_split) + else: + with mod_graph.inserting_after(node): + split_node = mod_graph.create_node('call_function', pipe_split) + if pp_size > 1: + node_counter = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + if node.op == 'placeholder': + continue + elif node_counter == 0: + node_counter += 1 + else: + pp_size -= 1 + node_counter = 0 + with mod_graph.inserting_before(node): + split_node = mod_graph.create_node('call_function', pipe_split) + + gm.recompile() + return gm + + +def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int): + """ + In balanced_split_pass_v12, we split module by the size of nodes(weights+bias+outputs). + """ + mod_graph = gm.graph + # To use balanced_split_pass_v2, we need run meta_info_prop interpreter first. + # If nodes don't have meta info, this pass will fall back to normal balanced split pass. + check_node = list(mod_graph.nodes)[0] + if 'tensor_meta' not in check_node.meta: + return balanced_split_pass(gm, pp_size) + + total_element_size = 0 + for node in mod_graph.nodes: + total_element_size += node.node_size + + partition_size = total_element_size // pp_size + accumulate_node_size = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + if 'pipe_split' in node.name: + continue + accumulate_node_size += node.node_size + if accumulate_node_size >= partition_size: + total_element_size = total_element_size - accumulate_node_size + accumulate_node_size = 0 + pp_size -= 1 + partition_size = total_element_size // pp_size + with mod_graph.inserting_after(node): + split_node = mod_graph.create_node('call_function', pipe_split) + gm.recompile() + return gm + + +def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int): + mod_graph = gm.graph + valid_children_size = 0 + valid_children = [] + for module in mod_graph.owning_module.children(): + valid_children_size += 1 + valid_children.append(module) + + if valid_children_size < pp_size: + # If valid children is not enough to shard, we will use balanced policy instead of uniform policy. + return balanced_split_pass(gm, pp_size) + layers_per_partition = valid_children_size // pp_size + accumulate_layer_amount = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + if node.op == "call_module": + target_module = node.graph.owning_module.get_submodule(node.target) + if target_module in valid_children: + accumulate_layer_amount += 1 + if accumulate_layer_amount == layers_per_partition: + accumulate_layer_amount = 0 + pp_size -= 1 + with mod_graph.inserting_after(node): + split_node = mod_graph.create_node('call_function', pipe_split) + gm.recompile() + return gm + + +def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output=False): + # TODO(lyl): use partition IR to assign partition ID to each node. + # Currently: analyzing graph -> annotate graph by inserting split node -> use split module pass to split graph + # In future: graph to partitions -> analyzing partition IR -> recombining partitions to get best performance -> assign partition ID to each node + part_idx = 0 + + def split_callback(n: torch.fx.Node): + nonlocal part_idx + if (n.op, n.target) == ('call_function', pipe_split): + part_idx += 1 + return part_idx + + split_mod = split_module(annotated_gm, None, split_callback, merge_output) + split_submodules = [] + for name, submodule in split_mod.named_modules(): + if isinstance(submodule, torch.fx.GraphModule): + for node in submodule.graph.nodes: + if (node.op, node.target) == ('call_function', pipe_split): + submodule.graph.erase_node(node) + submodule.recompile() + split_submodules.append(submodule) + + return split_mod, split_submodules diff --git a/colossalai/fx/passes/concrete_info_prop.py b/colossalai/fx/passes/concrete_info_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..81ac6420552815a5cea2d3fe5aef175efb36ece0 --- /dev/null +++ b/colossalai/fx/passes/concrete_info_prop.py @@ -0,0 +1,290 @@ +from dataclasses import asdict +from typing import Any, Dict, List, NamedTuple, Optional, Tuple + +import torch +import torch.fx +from torch.fx.node import Argument, Node, Target +from torch.utils._pytree import tree_flatten + +from colossalai.fx._compatibility import compatibility +from colossalai.fx.profiler import GraphInfo, profile_function, profile_method, profile_module + + +@compatibility(is_backward_compatible=True) +class ConcreteInfoProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node with concrete tensor and record the memory + usage, execution time of forward and backward, and type of the result into + the corresponding node. + + Usage: + BATCH_SIZE = 2 + DIM_IN = 4 + DIM_HIDDEN = 16 + DIM_OUT = 16 + model = torch.nn.Sequential( + torch.nn.Linear(DIM_IN, DIM_HIDDEN), + torch.nn.Linear(DIM_HIDDEN, DIM_OUT), + ).cuda() + input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda") + gm = symbolic_trace(model) + interp = ConcreteInfoProp(gm) + interp.run(input_sample) + print(interp.summary(unit='kb')) + + + output of above code is + Op type Op Forward time Backward time SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP + ----------- ------- ----------------------- ------------------------ ------------- --------- --------- --------- --------- + placeholder input_1 0.0 s 0.0 s False 0.00 KB 0.00 KB 0.00 KB 0.00 KB + call_module _0 0.0003993511199951172 s 0.00706791877746582 s False 0.50 KB 0.00 KB 0.03 KB 0.66 KB + call_module _1 6.29425048828125e-05 s 0.00018286705017089844 s False 0.50 KB 0.00 KB 0.12 KB 0.81 KB + output output 0.0 s 0.0 s True 0.00 KB 0.00 KB 0.00 KB 0.00 KB + Args: + module (GraphModule): The module to be executed + + """ + + _is_proped: bool = False + + def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any: + """Customized run for ConcreteInfoProp + We need to store the device in self.device + + Args: + *args: The arguments to the Module to run, in positional order + initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution. + This is a dict mapping `Node` to any value. This can be used, for example, to + pre-populate results for certain `Nodes` so as to do only partial evaluation within + the interpreter. + enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and + process_outputs function first before using them. + + Returns: + Any: The value returned from executing the Module + """ + + flatten_args, _ = tree_flatten(args) + self.device = next(item for item in flatten_args if hasattr(item, "device")).device + return super().run(*args, initial_env, enable_io_processing) + + @compatibility(is_backward_compatible=True) + def run_node(self, n: Node) -> Any: + """ + Run a specific node ``n`` and return the result. + Calls into placeholder, get_attr, call_function, + call_method, call_module, or output depending + on ``node.op`` + + Args: + n (Node): The Node to execute + + Returns: + Any: The result of executing ``n`` + """ + self._is_proped = True + result, meta_info = super().run_node(n) + + n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` + # TODO: the attribute node_size should be removed in the future + setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0)) + n.meta['type'] = type(result) + + # retain the autograd graph + for param in self.module.parameters(): + param.grad = None + + return result + + # Main Node running APIs + @compatibility(is_backward_compatible=True) + def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``placeholder`` node. Note that this is stateful: + ``Interpreter`` maintains an internal iterator over + arguments passed to ``run`` and this method returns + next() on that iterator. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Returns: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + return super().placeholder(target, args, kwargs), GraphInfo() + + @compatibility(is_backward_compatible=True) + def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``get_attr`` node. Will retrieve an attribute + value from the ``Module`` hierarchy of ``self.module``. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + return super().get_attr(target, args, kwargs), GraphInfo() + + @compatibility(is_backward_compatible=True) + def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_function`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + assert not isinstance(target, str) + return profile_function(target, self.device)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_method`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + return profile_method(target, self.device)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_module`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + # Retrieve executed args and kwargs values from the environment + # Execute the method and return the result + assert isinstance(target, str) + submod = self.fetch_attr(target) + return profile_module(submod, self.device)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute an ``output`` node. This really just retrieves + the value referenced by the ``output`` node and returns it. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + return args[0], GraphInfo(save_fwd_in=True) + + def propagate(self, *args): + """ + Run `module` via interpretation and return the result and + record the shape and type of each node. + + Args: + *args (Tensor): the sample input. + + Returns: + Any: The value returned from executing the Module + """ + return self.run(*args) + + def summary(self, unit: str = 'MB') -> str: + """ + Summarizes the memory and FLOPs statistics of the `GraphModule` in + tabular format. Note that this API requires the ``tabulate`` module + to be installed. + """ + # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py + try: + from tabulate import tabulate + except ImportError: + print("`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library.") + + assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`." + + # Build up a list of summary information for each node + node_summaries: List[List[Any]] = [] + + def mem_repr(mem: int) -> str: + unit_divisor_map = { + 'kb': 1024, + 'mb': 1024**2, + 'gb': 1024**3, + 'tb': 1024**4, + } + return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}" + + def time_repr(time: float): + return f"{time:,} s" + + for node in self.module.graph.nodes: + node: Node + node_summaries.append([ + node.op, + str(node), + time_repr(node.meta['fwd_time']), + time_repr(node.meta['bwd_time']), + node.meta['save_fwd_in'], + mem_repr(node.meta['fwd_mem_out']), + mem_repr(node.meta['fwd_mem_tmp']), + mem_repr(node.meta['bwd_mem_out']), + mem_repr(node.meta['bwd_mem_tmp']), + ]) + + # Use the ``tabulate`` library to create a well-formatted table + # presenting our summary information + headers: List[str] = [ + 'Op type', + 'Op', + 'Forward time', + 'Backward time', + 'SAVE_FWD_IN', + 'FWD_OUT', + 'FWD_TMP', + 'BWD_OUT', + 'BWD_TMP', + ] + + return tabulate(node_summaries, headers=headers, stralign='right') diff --git a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..f28d65e2668ac39e7b189c7d181d018468648614 --- /dev/null +++ b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py @@ -0,0 +1,111 @@ +import torch +from typing import List +from torch.fx import symbolic_trace +from torch.fx.node import Node +from colossalai.fx.passes.split_module import split_module +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec +import builtins +import operator +from copy import deepcopy + + +def apply(*args, **kwargs): + shape_consistency_manager = ShapeConsistencyManager() + return shape_consistency_manager.apply(*args, **kwargs) + + +def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh): + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + + # the dict to get origin sharding spec of node + origin_node_sharding_spec_dict = {} + for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)): + strategies_vector = node.strategies_vector + setattr(node, 'best_strategy', strategies_vector[strategy_index]) + setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec) + origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec + + # apply the sharding spec of parameters + for node in nodes: + if node.op == 'call_module': + target_module = node.graph.owning_module.get_submodule(node.target) + origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {}) + setattr(target_module.weight, 'sharding_spec', origin_sharding_spec) + target_weight_sharding_spec = node.best_strategy.input_shardings[1] + target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3)) + apply(target_module.weight, target_weight_sharding_spec) + target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3)) + + # the dict to get input sharding specs of user node + sharding_spec_convert_dict = {} + for index, node in enumerate(nodes): + target_sharding_specs = [] + for user_node in node.strategies_vector.successor_nodes: + node_index = user_node.strategies_vector.predecessor_nodes.index(node) + target_sharding_spec = user_node.best_strategy.input_shardings[node_index] + target_sharding_specs.append(target_sharding_spec) + sharding_spec_convert_dict[index] = target_sharding_specs + + # add above dicts into graph + for node in nodes: + if node.op != 'placeholder': + with mod_graph.inserting_before(node): + input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict') + origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict') + break + + return sharding_spec_convert_dict, origin_node_sharding_spec_dict + + +def shape_consistency_pass(gm: torch.fx.GraphModule): + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + input_dict_node = None + origin_dict_node = None + + # mapping the node into the origin graph index + node_to_index_dict = {} + index = 0 + for node in nodes: + if node.target == 'sharding_spec_convert_dict': + input_dict_node = node + continue + if node.target == 'origin_node_sharding_spec_dict': + origin_dict_node = node + continue + if not hasattr(node, 'best_strategy'): + continue + node_to_index_dict[node] = index + index += 1 + assert input_dict_node is not None + + # add shape consistency apply function into graph + for node in nodes: + if not hasattr(node, 'best_strategy'): + continue + with mod_graph.inserting_after(node): + origin_spec_node = mod_graph.create_node('call_function', + operator.getitem, + args=(origin_dict_node, node_to_index_dict[node])) + with mod_graph.inserting_after(origin_spec_node): + set_sharding_spec_node = mod_graph.create_node('call_function', + builtins.setattr, + args=(node, 'sharding_spec', origin_spec_node)) + + for user_node in node.strategies_vector.successor_nodes: + node_index = user_node.strategies_vector.predecessor_nodes.index(node) + with mod_graph.inserting_before(user_node): + input_specs_node = mod_graph.create_node('call_function', + operator.getitem, + args=(input_dict_node, node_to_index_dict[node])) + with mod_graph.inserting_before(user_node): + sharding_spec_node = mod_graph.create_node('call_function', + operator.getitem, + args=(input_specs_node, node_index)) + with mod_graph.inserting_before(user_node): + shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node)) + + return gm diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..2b4a8749cfd776e3ac22d75fc2e47c2475c521d6 --- /dev/null +++ b/colossalai/fx/passes/meta_info_prop.py @@ -0,0 +1,358 @@ +from dataclasses import asdict +from typing import Any, Dict, List, NamedTuple, Tuple + +import torch +import torch.fx +from torch.fx.node import Argument, Node, Target +from torch.utils._pytree import tree_map + +from colossalai.fx._compatibility import compatibility, is_compatible_with_meta +from colossalai.fx.profiler import ( + GraphInfo, + activation_size, + calculate_fwd_in, + calculate_fwd_out, + calculate_fwd_tmp, + profile_function, + profile_method, + profile_module, +) + + +@compatibility(is_backward_compatible=True) +class TensorMetadata(NamedTuple): + # TensorMetadata is a structure containing pertinent information + # about a tensor within a PyTorch program. + + shape: torch.Size + dtype: torch.dtype + requires_grad: bool + stride: Tuple[int] + numel: int + is_tensor: bool + # TODO: we can add a list of sharding spec here, and record the sharding + # behaviour by appending sharding spec into list. + + +def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: + """ + Extract a TensorMetadata NamedTuple describing `result`. + """ + shape = result.shape + dtype = result.dtype + requires_grad = result.requires_grad + stride = result.stride() + numel = result.numel() + is_tensor = True + + return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor) + + +@compatibility(is_backward_compatible=True) +class MetaInfoProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node with meta tensor and + record the memory usage, FLOPs, and type of the result + into the corresponding node. + + Usage: + BATCH_SIZE = 2 + DIM_IN = 4 + DIM_HIDDEN = 16 + DIM_OUT = 16 + model = torch.nn.Sequential( + torch.nn.Linear(DIM_IN, DIM_HIDDEN), + torch.nn.Linear(DIM_HIDDEN, DIM_OUT), + ) + input_sample = torch.rand(BATCH_SIZE, DIM_IN) + gm = symbolic_trace(model) + interp = MetaInfoProp(gm) + interp.run(input_sample) + print(interp.summary(format='kb')) # don't panic if some statistics are 0.00 MB + + + # output of above code is + Op type Op Forward FLOPs Backward FLOPs FWD_OUT FWD_TMP BWD_OUT BWD_TMP + ----------- ------- --------------- ---------------- --------- --------- --------- --------- + placeholder input_1 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB + call_module _0 128 FLOPs 288 FLOPs 0.12 KB 0.00 KB 0.34 KB 0.00 KB + call_module _1 512 FLOPs 1,056 FLOPs 0.12 KB 0.00 KB 1.19 KB 0.00 KB + output output 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB + Args: + module (GraphModule): The module to be executed + + """ + + _is_proped: bool = False + + @compatibility(is_backward_compatible=True) + def run_node(self, n: Node) -> Any: + """ + Run a specific node ``n`` and return the result. + Calls into placeholder, get_attr, call_function, + call_method, call_module, or output depending + on ``node.op`` + + Args: + n (Node): The Node to execute + + Returns: + Any: The result of executing ``n`` + """ + self._is_proped = True + result, meta_info = super().run_node(n) + + def extract_tensor_meta(obj): + if isinstance(obj, torch.Tensor): + return _extract_tensor_metadata(obj) + else: + return TensorMetadata(None, None, False, None, 0, False) + + tensor_meta = tree_map(extract_tensor_meta, result) + n.meta['tensor_meta'] = tensor_meta + n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` + # TODO: the attribute node_size should be removed in the future + setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0))) + setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0)) + setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0)) + n.meta['type'] = type(result) + + # retain the autograd graph + for param in self.module.parameters(): + param.grad = None + + return result + + # Main Node running APIs + @compatibility(is_backward_compatible=True) + def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``placeholder`` node. Note that this is stateful: + ``Interpreter`` maintains an internal iterator over + arguments passed to ``run`` and this method returns + next() on that iterator. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Returns: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + return super().placeholder(target, args, kwargs), GraphInfo() + + @compatibility(is_backward_compatible=True) + def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``get_attr`` node. Will retrieve an attribute + value from the ``Module`` hierarchy of ``self.module``. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + return super().get_attr(target, args, kwargs), GraphInfo() + + @compatibility(is_backward_compatible=True) + def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_function`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + assert not isinstance(target, str) + return profile_function(target)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_method`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + return profile_method(target)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_module`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + # Retrieve executed args and kwargs values from the environment + # Execute the method and return the result + assert isinstance(target, str) + submod = self.fetch_attr(target) + return profile_module(submod)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute an ``output`` node. This really just retrieves + the value referenced by the ``output`` node and returns it. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + if hasattr(args[0], '_tensor'): + return args[0], GraphInfo(fwd_in=[args[0]._tensor]) + return args[0], GraphInfo(save_fwd_in=True) + + def propagate(self, *args): + """ + Run `module` via interpretation and return the result and + record the shape and type of each node. + + Args: + *args (Tensor): the sample input. + + Returns: + Any: The value returned from executing the Module + """ + return super().run(*args) + + def summary(self, unit: str = 'MB') -> str: + """ + Summarizes the memory and FLOPs statistics of the `GraphModule` in + tabular format. Note that this API requires the ``tabulate`` module + to be installed. + """ + # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py + try: + from tabulate import tabulate + except ImportError: + print("`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library.") + + assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`." + + # Build up a list of summary information for each node + node_summaries: List[List[Any]] = [] + + def mem_repr(mem: int) -> str: + unit_divisor_map = { + 'kb': 1024, + 'mb': 1024**2, + 'gb': 1024**3, + 'tb': 1024**4, + } + return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}" + + def flops_repr(flop: int) -> str: + return f"{flop:,} FLOPs" + + accumulate_size = 0 + for node in self.module.graph.nodes: + node: Node + accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node) + node_summaries.append([ + node.op, + str(node), + flops_repr(node.meta['fwd_flop']), + flops_repr(node.meta['bwd_flop']), + mem_repr(accumulate_size), + mem_repr(calculate_fwd_in(node)), + mem_repr(calculate_fwd_out(node)), + mem_repr(calculate_fwd_tmp(node)), + mem_repr(node.meta['bwd_mem_out']), + mem_repr(node.meta['bwd_mem_tmp']), + ]) + + # Use the ``tabulate`` library to create a well-formatted table + # presenting our summary information + headers: List[str] = [ + 'Op type', + 'Op', + 'Forward FLOPs', + 'Backward FLOPs', + 'Accumulated Memory', + 'FWD_IN', + 'FWD_OUT', + 'FWD_TMP', + 'BWD_OUT', + 'BWD_TMP', + ] + + return tabulate(node_summaries, headers=headers, stralign='right') + + +def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None: + """ + MetaInfo tracing API + + Given a ``GraphModule`` and a sample input, this API will trace the MetaInfo of a single training cycle, + and annotate them on ``gm.graph``. + + Uses: + >>> model = ... + >>> gm = symbolic_trace(model) + >>> args = ... # sample input to the ``GraphModule`` + >>> metainfo_trace(gm, *args) + + Args: + gm (torch.fx.GraphModule): The ``GraphModule`` to be annotated with MetaInfo. + verbose (bool, optional): Whether to show ``MetaInfoProp.summary()`. Defaults to False. + unit (str, optional): The unit of memory. Defaults to "MB". + + Returns: + torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo. + """ + device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + interp = MetaInfoProp(gm.to(device)) + if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + args = tree_map(lambda x: MetaTensor(x, fake_device=device), args) + kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs) + interp.propagate(*args, **kwargs) + if verbose: + interp.summary(unit) + gm.to('cpu') + del interp + return gm diff --git a/colossalai/fx/passes/passes_for_gpt2_test.py b/colossalai/fx/passes/passes_for_gpt2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..abc1a089e9a90ebaff0681879aaa68d488edb624 --- /dev/null +++ b/colossalai/fx/passes/passes_for_gpt2_test.py @@ -0,0 +1,371 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional + +import torch +from packaging import version +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node + +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split +from colossalai.fx.passes.meta_info_prop import TensorMetadata +from colossalai.fx.passes.split_module import Partition + + +def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]): + ''' + This pass is only used to do the gpt2 performance test, it may move into adding_split_node_pass.py, and will be deprecated in future. + ''' + mod_graph = gm.graph + valid_children_size = 0 + valid_children = [] + for node in mod_graph.nodes: + if node.op == "call_module": + valid_children_size += 1 + valid_children.append(node.target) + if valid_children_size < pp_size: + # If valid children is not enough to shard, we will use balanced policy instead of uniform policy. + return balanced_split_pass(gm, pp_size) + accumulate_layer_amount = 0 + list_of_part = partition_list + part_index = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + if node.op == "call_module": + if node.target in valid_children: + accumulate_layer_amount += 1 + if accumulate_layer_amount == list_of_part[part_index]: + part_index += 1 + pp_size -= 1 + with mod_graph.inserting_after(node): + split_node = mod_graph.create_node('call_function', pipe_split) + + gm.recompile() + return gm + + +def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule): + ''' + This pass will be used in gpt2 test, only a part of changes may be added into + split_with_split_nodes_pass, and it will be deprecated in future. + ''' + part_idx = 0 + + def eliminate_unused_placeholders(gm): + for node in gm.graph.nodes: + if node.op == 'placeholder': + if not len(node.users): + gm.graph.erase_node(node) + gm.recompile() + return gm + + def refill_outputs_and_placeholders(gm, next_partition_placeholders): + ''' + This method is used to eliminate the outputs in previous partition which is unused in next partition. + In split module pass, it treats partitions as a DAG, but we need treat them as a single direction linked list in pipeline parallel. + The difference is if a output from partition 0 is an input argument of partition 3, the DAG will not transfer it + to partition 1 and partition 2. However, in single direction linked list, we need to do so. + ''' + output_type = None + output_args = [] + non_output_list = [] + new_placeholder_list = [] + for node in gm.graph.nodes: + if node.op == 'output': + if isinstance(node.args[0], (tuple, list)): + output_type = node.args[0].__class__ + output_args.extend([n.name for n in node.args[0]]) + else: + output_args.append(node.args[0].name) + rm_list = [] + for name in output_args: + if next_partition_placeholders and name not in next_partition_placeholders: + rm_list.append(name) + for name in rm_list: + output_args.remove(name) + gm.graph.erase_node(node) + else: + non_output_list.append(node.name) + + for name in next_partition_placeholders: + if name not in output_args: + output_args.append(name) + + for name in output_args: + if name not in non_output_list: + gm.graph.placeholder(name) + + # convert name to node for output_args + for index, name in enumerate(output_args): + for n in gm.graph.nodes: + if n.name == name: + output_args[index] = n + continue + + # reorder the output args to make sure + # output args has same order as next partition placeholder + reorder_output_args = [] + if next_partition_placeholders: + for name in next_partition_placeholders: + for node in output_args: + if node.name == name: + reorder_output_args.append(node) + continue + + for node in gm.graph.nodes: + if node.op == 'placeholder': + new_placeholder_list.append(node.name) + if output_type is not None: + gm.graph.output(output_type(output_args)) + else: + gm.graph.output(output_args) + gm.recompile() + return gm, new_placeholder_list + + def split_callback(n: torch.fx.Node): + nonlocal part_idx + if (n.op, n.target) == ('call_function', pipe_split): + part_idx += 1 + return part_idx + + split_mod = split_module_for_gpt2_test(annotated_gm, None, split_callback) + split_submodules = [] + for name, submodule in split_mod.named_modules(): + if isinstance(submodule, torch.fx.GraphModule): + for node in submodule.graph.nodes: + if (node.op, node.target) == ('call_function', pipe_split): + submodule.graph.erase_node(node) + submodule.recompile() + split_submodules.append(submodule) + + submodules = list(split_mod.children()) + placeholder_dict = {} + for submodule in submodules: + submodule = eliminate_unused_placeholders(submodule) + placeholder_dict[submodule] = [] + submodules.reverse() + for index, submodule in enumerate(submodules): + if index == 0: + placeholder_list = [] + else: + placeholder_list = placeholder_dict[submodules[index - 1]] + submodule, placeholder_dict[submodule] = refill_outputs_and_placeholders(submodule, placeholder_list) + submodule.recompile() + + split_mod.recompile() + + return split_mod, split_submodules + + +@compatibility(is_backward_compatible=True) +def split_module_for_gpt2_test( + m: GraphModule, + root_m: torch.nn.Module, + split_callback: Callable[[torch.fx.node.Node], int], +): + """ + This pass will be used in gpt2 pp performance test, only a part of changes may be added into + split_module, and it will be deprecated in future. + """ + partitions: Dict[str, Partition] = {} + orig_nodes: Dict[str, torch.fx.node.Node] = {} + + def _node_with_all_tensor_element(node_metadata: Any) -> int: + """ + return whether node contains non-tensor element. + """ + all_tensor_node = True + + if isinstance(node_metadata, TensorMetadata): + all_tensor_node = node_metadata.is_tensor and all_tensor_node + elif isinstance(node_metadata, dict): + value_list = [v for _, v in node_metadata.items()] + all_tensor_node += _node_with_all_tensor_element(value_list) + else: + for element in node_metadata: + all_tensor_node += _node_with_all_tensor_element(element) + + return all_tensor_node + + def _move_all_ancestors_into_partition(node, partition_name): + all_ancestors = set() + + def _gen_all_ancestors_set(node): + all_ancestors.add(node) + for n in node.all_input_nodes: + if n in all_ancestors: + continue + _gen_all_ancestors_set(n) + + _gen_all_ancestors_set(node) + for n in list(all_ancestors): + if n.op != 'placeholder' and n._fx_partition > partition_name: + n._fx_partition = partition_name + + def record_cross_partition_use(def_node: torch.fx.node.Node, + use_node: Optional[torch.fx.node.Node]): # noqa: B950 + def_partition_name = getattr(def_node, '_fx_partition', None) + use_partition_name = getattr(use_node, '_fx_partition', None) + if def_partition_name != use_partition_name: + # if 'tensor_meta' in def_node.meta: + # if not _node_with_all_tensor_element(def_node.meta['tensor_meta']): + # _move_all_ancestors_into_partition(use_node, def_partition_name) + # node_process_list.extend(use_node.all_input_nodes) + # node_process_list.extend(list(use_node.users)) + # node_process_list.append(use_node) + + # return + + if def_partition_name is not None: + def_partition = partitions[def_partition_name] + def_partition.outputs.setdefault(def_node.name) + if use_partition_name is not None: + def_partition.partition_dependents.setdefault(use_partition_name) + + if use_partition_name is not None: + use_partition = partitions[use_partition_name] + use_partition.inputs.setdefault(def_node.name) + if def_partition_name is not None: + use_partition.partitions_dependent_on.setdefault(def_partition_name) + + node_process_list = list(m.graph.nodes) + # split nodes into parititons + while node_process_list: + node = node_process_list.pop(0) + orig_nodes[node.name] = node + + if node.op in ["placeholder"]: + continue + if node.op == 'output': + # partition_name = str(split_callback(node)) + # def _set_output_args_partition(n, partition_name): + # n._fx_partition = partition_name + # torch.fx.graph.map_arg(node.args[0], lambda n: _set_output_args_partition(n, partition_name)) + torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None)) + continue + partition_name = str(split_callback(node)) + + # add node to partitions + partition = partitions.get(partition_name) + if partition is None: + partitions[partition_name] = partition = Partition(partition_name) + + partition.node_names.append(node.name) + origin_partition_name = getattr(node, '_fx_partition', None) + if origin_partition_name is None: + node._fx_partition = partition_name + + torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node)) + torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 + + # find partitions with no dependencies + root_partitions: List[str] = [] + for partition_name, partition in partitions.items(): + if not len(partition.partitions_dependent_on): + root_partitions.append(partition_name) + + # check partitions for circular dependencies and create topological partition ordering + sorted_partitions: List[str] = [] + while root_partitions: + root_partition = root_partitions.pop() + sorted_partitions.append(root_partition) + for dependent in partitions[root_partition].partition_dependents: + partitions[dependent].partitions_dependent_on.pop(root_partition) + if not partitions[dependent].partitions_dependent_on: + root_partitions.append(dependent) + if len(sorted_partitions) != len(partitions): + raise RuntimeError("cycle exists between partitions!") + + # add placeholders to parititons + for partition_name in sorted_partitions: + partition = partitions[partition_name] + for input in partition.inputs: + placeholder = partition.graph.placeholder(input) + placeholder.meta = orig_nodes[input].meta.copy() + partition.environment[orig_nodes[input]] = placeholder + + # Transform nodes and collect targets for partition's submodule + for node in m.graph.nodes: + if hasattr(node, '_fx_partition'): + partition = partitions[node._fx_partition] + + # swap out old graph nodes in kw/args with references to new nodes in this submodule + environment = partition.environment + gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) + gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n]) + + if node.op not in ['call_module', 'get_attr']: + target = node.target + else: + target_atoms = node.target.split('.') + target_attr = m + for atom in target_atoms: + if not hasattr(target_attr, atom): + raise RuntimeError(f'Operator target {node.target} not found!') + target_attr = getattr(target_attr, atom) + # target = target_atoms[-1] + target = '_'.join(target_atoms) + partition.targets[target] = target_attr + + assert isinstance(gathered_args, tuple) + assert isinstance(gathered_kwargs, dict) + new_node = partition.graph.create_node(op=node.op, + target=target, + args=gathered_args, + kwargs=gathered_kwargs, + name=node.name) + new_node.meta = node.meta.copy() + partition.environment[node] = new_node + + # Set up values to construct base module + base_mod_env: Dict[str, torch.fx.node.Node] = {} + base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() + base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} + for node in m.graph.nodes: + if node.op == 'placeholder': + if version.parse(torch.__version__) < version.parse('1.11.0'): + base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type) + else: + default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty + base_mod_env[node.name] = base_mod_graph.placeholder(node.name, + type_expr=node.type, + default_value=default_value) + base_mod_env[node.name].meta = node.meta.copy() + + # Do some things iterating over the partitions in topological order again: + # 1) Finish off submodule Graphs by setting corresponding outputs + # 2) Construct GraphModules for each submodule + # 3) Construct the base graph by emitting calls to those submodules in + # topological order + + for partition_name in sorted_partitions: + partition = partitions[partition_name] + + # Set correct output values + output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs) + output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] + partition.graph.output(output_vals) + + # Construct GraphModule for this partition + submod_name = f'submod_{partition_name}' + base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, + partition.graph) # noqa: B950 + + # Emit call in base graph to this submodule + output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs)) + if len(partition.outputs) > 1: + # Unpack multiple return values from submodule + output_val_proxy = torch.fx.proxy.Proxy(output_val) + for i, output_name in enumerate(partition.outputs): + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] + else: + if not partition.outputs: + continue + base_mod_env[list(partition.outputs)[0]] = output_val + + for node in m.graph.nodes: + if node.op == 'output': + base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 + + return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..d2bad06bb45a1393543039ef55dea6c8d1e9f50b --- /dev/null +++ b/colossalai/fx/passes/shard_1d_pass.py @@ -0,0 +1,151 @@ +import torch +import torch.nn as nn +import operator +from colossalai.tensor import ProcessGroup +from colossalai.tensor.distspec import ShardSpec +from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec + +ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] +ELEMENTWISE_FUNC_OP = [ + torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, + operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout +] + + +def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter: + """weight_split + split a nn.Parameter + + Args: + weight (torch.nn.parameter.Parameter): a torch Parameter instance + dim (int): the dimension to be sharded along with + col_normal(bool): col shard with gather or not + Returns: + _type_: _description_ + """ + if col_normal: + setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_normal")) + else: + setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_needs_many_outputs")) + return weight + + +def column_shard_linear_pass(gm: torch.fx.GraphModule): + # Split all the linear module with column shard. Currently for testing only. + mod_graph = gm.graph + for node in mod_graph.nodes: + if node.op == "call_module": + target_module = node.graph.owning_module.get_submodule(node.target) + if isinstance(target_module, torch.nn.Linear): + target_module.weight = weight_split(target_module.weight, dim=0, col_normal=False) + if target_module.bias is not None: + target_module.bias.data = weight_split(target_module.bias.data, dim=0, col_normal=False) + + gm.recompile() + return gm + + +def row_shard_linear_pass(gm: torch.fx.GraphModule): + # Split all the linear module with row shard. Currently for testing only. + mod_graph = gm.graph + for node in mod_graph.nodes: + if node.op == "call_module": + target_module = node.graph.owning_module.get_submodule(node.target) + if isinstance(target_module, torch.nn.Linear): + target_module.weight = weight_split(target_module.weight, dim=-1, col_normal=False) + + gm.recompile() + return gm + + +def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup): + """ + This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers. + """ + #TODO: Needs to handle special cases, like x = linear(x) + linear(x) + graph = graph_module.graph + world_size = process_group.world_size() + + def _traverse_and_annotate(node, start_tracking, annotation_record, world_size): + # traverse the graph to look for consecutive linear layers + is_linear_module = False + + if node.op == 'call_module': + # look for the linear layer + module = node.graph.owning_module.get_submodule(node.target) + if isinstance(module, nn.Linear): + is_linear_module = True + if start_tracking: + # when start_tracking = True + # it means the first linear has been found and the current module + # is the second linear + # set the current linear module to be row-sharded + annotation_record['row'] = module + + for shard_type, module in annotation_record.items(): + # add row sharding spec + if shard_type == 'row': + dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size]) + comp_spec = ComputeSpec(ComputePattern.TP1D) + setattr(module.weight, 'pg', process_group) + setattr(module.weight, 'dist_spec', dist_spec) + setattr(module.weight, 'comp_spec', comp_spec) + elif shard_type == 'col': + weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size]) + weight_comp_spec = ComputeSpec(ComputePattern.TP1D) + weight_comp_spec.output_replicate = False + setattr(module.weight, 'pg', process_group) + setattr(module.weight, 'dist_spec', weight_dist_spec) + setattr(module.weight, 'comp_spec', weight_comp_spec) + + if module.bias is not None: + bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size]) + bias_comp_spec = ComputeSpec(ComputePattern.TP1D) + bias_comp_spec.output_replicate = False + setattr(module.bias, 'pg', process_group) + setattr(module.bias, 'dist_spec', bias_dist_spec) + setattr(module.bias, 'comp_spec', bias_comp_spec) + start_tracking = False + annotation_record.clear() + else: + # when start tracking = False + # it means the current layer is the first linear + # set the linear layer to be col-sharded + start_tracking = True + annotation_record['col'] = module + + if start_tracking and not is_linear_module: + # check against the white list + # if non-element wise op is found, we reset the tracking + if node.op == 'call_module': + module = node.graph.owning_module.get_submodule(node.target) + if module.__class__ not in ELEMENTWISE_MODULE_OP: + start_tracking = False + elif node.op == 'call_function' or node.op == 'call_method': + if node.target not in ELEMENTWISE_FUNC_OP: + start_tracking = False + elif len(node.users.keys()) > 1: + start_tracking = False + + if not start_tracking: + annotation_record.clear() + + # stop tracking for consecutive linear when branch is found + # e.g. + # out1 = self.linear1(x) + # out2 = self.linear2(x) + # return out1+out2 + next_nodes = list(node.users.keys()) + if len(next_nodes) > 1: + start_tracking = False + annotation_record.clear() + + # traverse + for node in next_nodes: + _traverse_and_annotate(node, start_tracking, annotation_record, world_size) + + placeholder_node = list(graph.nodes)[0] + annotate_record = {} + _traverse_and_annotate(placeholder_node, False, annotate_record, world_size) + + return graph_module diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py new file mode 100644 index 0000000000000000000000000000000000000000..5ce5b969cbdefc0e347eb9f156f5268658e988fc --- /dev/null +++ b/colossalai/fx/passes/split_module.py @@ -0,0 +1,296 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional + +import torch +from packaging import version +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule + + +@compatibility(is_backward_compatible=True) +class Partition: + """ + Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py + """ + + def __init__(self, name: str): + self.name: str = name + self.node_names: List[str] = [] + self.inputs: Dict[str, None] = {} + self.outputs: Dict[str, None] = {} + self.partitions_dependent_on: Dict[str, None] = {} + self.partition_dependents: Dict[str, None] = {} + self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph() + self.environment: Dict[torch.fx.node.Node, torch.fx.node.Node] = {} + self.targets: Dict[str, Any] = {} + + def __repr__(self) -> str: + return f"name: {self.name},\n" \ + f" nodes: {self.node_names},\n" \ + f" inputs: {self.inputs},\n" \ + f" outputs: {self.outputs},\n" \ + f" partitions depenent on: {self.partitions_dependent_on},\n" \ + f" parition dependents: {self.partition_dependents}" + + +# Creates subgraphs out of main graph +@compatibility(is_backward_compatible=True) +def split_module( + m: GraphModule, + root_m: torch.nn.Module, + split_callback: Callable[[torch.fx.node.Node], int], + merge_output=False, +): + """ + Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py + Creates subgraphs out of main graph + Args: + m (GraphModule): Graph module to split + root_m (torch.nn.Module): root nn module. Not currently used. Included + because the root nn module is usually transformed via + torch.fx._symbolic_trace.symbolic_trace (see example below) + split_callback (Callable[[torch.fx.node.Node], int]): Callable function + that maps a given Node instance to a numeric partition identifier. + split_module will use this function as the policy for which operations + appear in which partitions in the output Module. + Returns: + GraphModule: the module after split. + Example: + This is a sample setup: + import torch + from torch.fx.symbolic_trace import symbolic_trace + from torch.fx.graph_module import GraphModule + from torch.fx.node import Node + from colossalai.fx.passes.split_module import split_module + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + def forward(self, x, y): + z = self.linear(x + self.param).clamp(min=0.0, max=1.0) + w = self.linear(y).clamp(min=0.0, max=1.0) + return z + w + # symbolically trace model + my_module = MyModule() + my_module_traced = symbolic_trace(my_module) + # random mod partitioning + partition_counter = 0 + NPARTITIONS = 3 + def mod_partition(node: Node): + global partition_counter + partition = partition_counter % NPARTITIONS + partition_counter = (partition_counter + 1) % NPARTITIONS + return partition + # split module in module with submodules + module_with_submodules = split_module( + my_module_traced, my_module, mod_partition + ) + Output looks like this. Original graph is broken into partitions + > print(module_with_submodules) + GraphModule( + (submod_0): GraphModule( + (linear): Linear(in_features=4, out_features=5, bias=True) + ) + (submod_1): GraphModule( + (linear): Linear(in_features=4, out_features=5, bias=True) + ) + (submod_2): GraphModule() + ) + def forward(self, x, y): + param = self.param + submod_0 = self.submod_0(x, param, y); x = param = y = None + getitem = submod_0[0] + getitem_1 = submod_0[1]; submod_0 = None + submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None + getitem_2 = submod_1[0] + getitem_3 = submod_1[1]; submod_1 = None + submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None + return submod_2 + Output of split module is the same as output of input traced module. + This is an example within a test setting: + > orig_out = my_module_traced(x, y) + > submodules_out = module_with_submodules(x, y) + > self.assertEqual(orig_out, submodules_out) + True + """ + partitions: Dict[str, Partition] = {} + orig_nodes: Dict[str, torch.fx.node.Node] = {} + + def record_cross_partition_use(def_node: torch.fx.node.Node, + use_node: Optional[torch.fx.node.Node]): # noqa: B950 + def_partition_name = getattr(def_node, '_fx_partition', None) + use_partition_name = getattr(use_node, '_fx_partition', None) + if def_partition_name != use_partition_name: + if def_partition_name is not None: + def_partition = partitions[def_partition_name] + def_partition.outputs.setdefault(def_node.name) + if use_partition_name is not None: + def_partition.partition_dependents.setdefault(use_partition_name) + + if use_partition_name is not None: + use_partition = partitions[use_partition_name] + use_partition.inputs.setdefault(def_node.name) + if def_partition_name is not None: + use_partition.partitions_dependent_on.setdefault(def_partition_name) + + def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 + def_partition_name = getattr(def_node, "_fx_partition", None) + use_partition_name = getattr(use_node, "_fx_partition", None) + if def_partition_name != use_partition_name: + if def_partition_name is not None: + def_partition = partitions[def_partition_name] + def_partition.outputs.setdefault(def_node.name) + if use_partition_name is not None: + def_partition.partition_dependents.setdefault(use_partition_name) + + if use_partition_name is not None: + use_partition = partitions[use_partition_name] + use_partition.inputs.setdefault(def_node.name) + if def_partition_name is not None: + use_partition.partitions_dependent_on.setdefault(def_partition_name) + use_partition.outputs.setdefault(def_node.name) + else: + if use_partition_name is not None: + use_partition = partitions[use_partition_name] + use_partition.outputs.setdefault(def_node.name) + + # split nodes into partitions + for node in m.graph.nodes: + orig_nodes[node.name] = node + + if node.op in ["placeholder"]: + continue + if node.op == 'output': + if merge_output: + torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev)) + else: + torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None)) + continue + partition_name = str(split_callback(node)) + + # add node to partitions + partition = partitions.get(partition_name) + if partition is None: + partitions[partition_name] = partition = Partition(partition_name) + + partition.node_names.append(node.name) + node._fx_partition = partition_name + + torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node)) + torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 + + # find partitions with no dependencies + root_partitions: List[str] = [] + for partition_name, partition in partitions.items(): + if not len(partition.partitions_dependent_on): + root_partitions.append(partition_name) + + # check partitions for circular dependencies and create topological partition ordering + sorted_partitions: List[str] = [] + while root_partitions: + root_partition = root_partitions.pop() + sorted_partitions.append(root_partition) + for dependent in partitions[root_partition].partition_dependents: + partitions[dependent].partitions_dependent_on.pop(root_partition) + if not partitions[dependent].partitions_dependent_on: + root_partitions.append(dependent) + if len(sorted_partitions) != len(partitions): + raise RuntimeError("cycle exists between partitions!") + + # add placeholders to partitions + for partition_name in sorted_partitions: + partition = partitions[partition_name] + for input in partition.inputs: + placeholder = partition.graph.placeholder(input) + placeholder.meta = orig_nodes[input].meta.copy() + partition.environment[orig_nodes[input]] = placeholder + + # Transform nodes and collect targets for partition's submodule + for node in m.graph.nodes: + if hasattr(node, '_fx_partition'): + partition = partitions[node._fx_partition] + + # swap out old graph nodes in kw/args with references to new nodes in this submodule + environment = partition.environment + gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) + gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n]) + + if node.op not in ['call_module', 'get_attr']: + target = node.target + else: + target_atoms = node.target.split('.') + target_attr = m + for atom in target_atoms: + if not hasattr(target_attr, atom): + raise RuntimeError(f'Operator target {node.target} not found!') + target_attr = getattr(target_attr, atom) + # target = target_atoms[-1] + target = '_'.join(target_atoms) + partition.targets[target] = target_attr + + assert isinstance(gathered_args, tuple) + assert isinstance(gathered_kwargs, dict) + new_node = partition.graph.create_node(op=node.op, + target=target, + args=gathered_args, + kwargs=gathered_kwargs) + new_node.meta = node.meta.copy() + partition.environment[node] = new_node + + # Set up values to construct base module + base_mod_env: Dict[str, torch.fx.node.Node] = {} + base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() + base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} + for node in m.graph.nodes: + if node.op == 'placeholder': + if version.parse(torch.__version__) < version.parse('1.11.0'): + base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type) + else: + default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty + base_mod_env[node.name] = base_mod_graph.placeholder(node.target, + type_expr=node.type, + default_value=default_value) + base_mod_env[node.name].meta = node.meta.copy() + + # Do some things iterating over the partitions in topological order again: + # 1) Finish off submodule Graphs by setting corresponding outputs + # 2) Construct GraphModules for each submodule + # 3) Construct the base graph by emitting calls to those submodules in + # topological order + + for partition_name in sorted_partitions: + partition = partitions[partition_name] + + # Set correct output values + output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs) + output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] + partition.graph.output(output_vals) + + # Construct GraphModule for this partition + submod_name = f'submod_{partition_name}' + base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, + partition.graph) # noqa: B950 + + # Emit call in base graph to this submodule + output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs)) + if len(partition.outputs) > 1: + # Unpack multiple return values from submodule + output_val_proxy = torch.fx.proxy.Proxy(output_val) + for i, output_name in enumerate(partition.outputs): + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] + else: + if not partition.outputs: + continue + base_mod_env[list(partition.outputs)[0]] = output_val + + for node in m.graph.nodes: + if node.op == 'output': + base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 + + for partition_name in sorted_partitions: + partition = partitions[partition_name] + + new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) + + return new_gm diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bb4f3cd6a4908177ca13f9d7fb82ff42b5ad1d5e --- /dev/null +++ b/colossalai/fx/passes/utils.py @@ -0,0 +1,172 @@ +import torch +from typing import Dict +from torch.fx.node import Node, map_arg +from torch.fx.graph import Graph + +def get_comm_size(prev_partition, next_partition): + """ + Given two partitions (parent and child), + calculate the communication size between the two. + """ + # Keep tracking the communication size between parent and child + comm_size = 0 + # Keep tracking all the counted node + visited_nodes = set() + # Go through all nodes in the child partition + # If a node has input nodes from the parent partition, + # the output size of those input nodes will be counted + # and added to comm_size + parent_node_names = [n.name for n in prev_partition.graph.nodes] + for node in next_partition.graph.nodes: + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + for n in input_nodes: + if n.name in parent_node_names and n not in visited_nodes: + comm_size += n.meta['tensor_meta'].numel + visited_nodes.add(n) + return comm_size + + +def get_leaf(graph: Graph): + """ + Given a graph, return leaf nodes of this graph. + Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph, + we will get a normal DAG. Leaf nodes in this context means leaf nodes in that DAG. + """ + input_nodes: Dict[Node, None] = {} + for node in graph.nodes: + if node.op == 'output': + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + placeholder_nodes = [] + for node in input_nodes.keys(): + if node.op == 'placeholder': + placeholder_nodes.append(node) + for node in placeholder_nodes: + input_nodes.pop(node) + return list(input_nodes.keys()) + + +def is_leaf(graph: Graph, node: Node): + return node in get_leaf(graph) + + +def get_top(graph: Graph): + """ + Given a graph, return top nodes of this graph. + Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph, + we will get a normal DAG. Top nodes in this context means nodes with BFS level 0 in that DAG. + """ + top_node_list = set() + for node in graph.nodes: + if node.op == 'output': + continue + is_top = False + + def _get_top(node): + nonlocal is_top + if node.op == 'placeholder': + is_top = True + + map_arg(node.args, lambda n: _get_top(n)) + map_arg(node.kwargs, lambda n: _get_top(n)) + if is_top: + top_node_list.add(node) + return list(top_node_list) + + +def is_top(graph: Graph, node: Node): + return node in get_top(graph) + + +def get_all_consumers(graph: Graph, node: Node): + """ + Given a graph and a node of this graph, return all consumers of the node. + + Returns: + List of ``Nodes`` that node appear in these nodes ``args`` and ``kwargs``. + """ + consumer_list = [] + for n in graph.nodes: + if node in n.all_input_nodes: + consumer_list.append(n) + return consumer_list + + +def assign_bfs_level_to_nodes(graph: Graph): + """ + Give a graph, assign bfs level to each node of this graph excluding ``placeholder`` and ``output`` nodes. + Example: + class MLP(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.linear3 = torch.nn.Linear(dim, dim) + self.linear4 = torch.nn.Linear(dim, dim) + self.linear5 = torch.nn.Linear(dim, dim) + def forward(self, x): + l1 = self.linear1(x) + l2 = self.linear2(x) + l3 = self.linear3(l1) + l4 = self.linear4(l2) + l5 = self.linear5(l3) + return l4, l5 + model = MLP(4) + gm = symbolic_trace(model) + print(gm.graph) + assign_bfs_level_to_nodes(gm.graph) + for node in gm.graph.nodes: + if hasattr(node, 'bfs_level'): + print(node.name, node.bfs_level) + + Output: + graph(): + %x : [#users=2] = placeholder[target=x] + %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {}) + %linear2 : [#users=1] = call_module[target=linear2](args = (%x,), kwargs = {}) + %linear3 : [#users=1] = call_module[target=linear3](args = (%linear1,), kwargs = {}) + %linear4 : [#users=1] = call_module[target=linear4](args = (%linear2,), kwargs = {}) + %linear5 : [#users=1] = call_module[target=linear5](args = (%linear3,), kwargs = {}) + return (linear4, linear5) + linear1 0 + linear2 0 + linear3 1 + linear4 1 + linear5 2 + """ + current_level = 0 + nodes_to_process = [] + + top_nodes = get_top(graph) + for node in top_nodes: + node.bfs_level = current_level + nodes_to_process.extend(get_all_consumers(graph, node)) + + current_level += 1 + while nodes_to_process: + new_process_list = [] + for node in nodes_to_process: + if node.op == 'output': + continue + node.bfs_level = current_level + new_process_list.extend(get_all_consumers(graph, node)) + nodes_to_process = new_process_list + current_level += 1 + + +def get_node_module(node) -> torch.nn.Module: + """ + Find the module associated with the given node. + Args: + node (torch.fx.Node): a torch.fx.Node object in the fx computation graph + Returns: + torch.nn.Module: the module associated with the given node + """ + + assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object' + assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}' + module = node.graph.owning_module.get_submodule(node.target) + return module + diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bcbde0eb23b806b7e37e407d57a962d8ff71573 --- /dev/null +++ b/colossalai/fx/profiler/__init__.py @@ -0,0 +1,18 @@ +from .._compatibility import is_compatible_with_meta + +if is_compatible_with_meta(): + from .opcount import flop_mapping + from .profiler import profile_function, profile_method, profile_module + from .shard_utils import ( + calculate_bwd_time, + calculate_fwd_in, + calculate_fwd_out, + calculate_fwd_time, + calculate_fwd_tmp, + ) + from .tensor import MetaTensor +else: + from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out + +from .dataflow import GraphInfo +from .memory_utils import activation_size, is_inplace, parameter_size diff --git a/colossalai/fx/profiler/constants.py b/colossalai/fx/profiler/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..5763a46dc83f19dadefbebb32dbcf9a59578a2b3 --- /dev/null +++ b/colossalai/fx/profiler/constants.py @@ -0,0 +1,44 @@ +import torch + +__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'RELU_LIKE_OPS', 'RELU_LIKE_MOD'] + +aten = torch.ops.aten + +ALIAS_ATEN = [ + aten.detach.default, + aten.t.default, + aten.transpose.int, + aten.view.default, + aten._unsafe_view.default, + aten._reshape_alias.default, +] + +INPLACE_NEW = [ + aten.empty_like.default, + aten.new_empty_strided.default, +] + +INPLACE_MATH_ATEN = [ + aten.add_.Tensor, + aten.sub_.Tensor, + aten.div_.Tensor, + aten.div_.Scalar, + aten.mul_.Tensor, + aten.bernoulli_.float, +] + +CLONE_ATEN = [ + aten.clone.default, +] + +# See illustrations in +# https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/fx/profiler/constants.py +OUTPUT_SAVED_OPS = [ + torch.nn.functional.relu, + torch.nn.functional.softmax, +] + +OUTPUT_SAVED_MOD = [ + torch.nn.ReLU, + torch.nn.Softmax, +] diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e8880322b84f54e6c4742a821f4de76dfb664a --- /dev/null +++ b/colossalai/fx/profiler/dataflow.py @@ -0,0 +1,141 @@ +from dataclasses import dataclass, field +from enum import Enum +from functools import partial +from typing import Dict, List + +from torch.fx import Graph, Node + +from .._compatibility import compatibility +from .memory_utils import activation_size, is_inplace + + +class Phase(Enum): + FORWARD = 0 + BACKWARD = 1 + PLACEHOLDER = 2 + + +@compatibility(is_backward_compatible=True) +@dataclass +class GraphInfo: + """ + GraphInfo is a dataclass for MetaInfo, which measures + the execution memory cost and FLOPs with `MetaTensor`. + The dataflow analysis is conducted on a single node of the FX graph. + ============================================================================ + ------------------------------- + | Node | + [fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] is marks the memory for `grad_out`. + placeholders saved for | | \__________ | | + backward. | | \ | | + | [fwd_tmp] ------> [bwd_tmp] | <----- + | | \_________ | | [bwd_tmp] marks the peak memory + | / \ \ | | in backward pass. + [x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <----- + in [fwd_tmp] because | | \_____ | | + it is not saved for | | \ | | + backward. | [fwd_out] \ | | <----- [fwd_out] is [fwd_in] for the next node. + ------------------------------- + ============================================================================ + Attributes: + fwd_flop (int): The forward FLOPs of a certain node. + fwd_time (float): The real forward time (s) of a certain node. + bwd_flop (int): The backward FLOPs of a certain node. + bwd_time (float): The real backward time (s) of a certain node. + save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes. + fwd_in (List): See the above illustration. + fwd_tmp (List): See the above illustration. + fwd_out (List): See the above illustration. + fwd_mem_tmp (int): See the above illustration. + fwd_mem_out (int): See the above illustration. + bwd_mem_tmp (int): See the above illustration. + bwd_mem_out (int): See the above illustration. + """ + + # TODO(super-dainiu): removed redundant items, currently all of them are necessary for development + + fwd_flop: int = 0 + fwd_time: float = 0.0 + bwd_flop: int = 0 + bwd_time: float = 0.0 + save_fwd_in: bool = False + fwd_in: List = field(default_factory=list) + fwd_tmp: List = field(default_factory=list) + fwd_out: List = field(default_factory=list) + fwd_mem_tmp: int = 0 + fwd_mem_out: int = 0 + bwd_mem_tmp: int = 0 + bwd_mem_out: int = 0 + + +def is_phase(n: Node, phase: Phase) -> bool: + assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!' + return n.meta['phase'] == phase + + +@compatibility(is_backward_compatible=False) +def autograd_graph_analysis(graph: Graph) -> GraphInfo: + """Analyze the autograd node dependencies and find out the memory usage. + Basically the input graph should have all nodes marked for keyword `phase`. + Nodes should have attribute `out` indicating the output of each node. + ============================================================================ + Placeholder ----> p o <---- We need to keep track of grad out + |\________ | + ↓ ↘| + f --------> b + |\ \_____ ↑ + | \ ↘ / + f f ----> b <---- Not every forward result needs to be saved for backward + | \____ ↑ + ↘ ↘| + f ----> b <---- Backward can be freed as soon as it is required no more. + ↘ ↗ + l + ============================================================================= + Args: + graph (Graph): The autograd graph with nodes marked for keyword `phase`. + + Returns: + graph_info (GraphInfo): Meta information for the dataflow. + """ + + def _peak_memory(deps: Dict[Node, int]): + peak_mem = 0 + for k, v in deps.items(): + if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k): + peak_mem += activation_size(k.meta['saved_tensor']) + if v <= float('-inf') and is_phase(k, Phase.FORWARD): + peak_mem -= activation_size(k.meta['saved_tensor']) + return peak_mem + + # deps is used to track all the memory dependencies of the graph. + deps = {} + graph_info = GraphInfo() + + for n in graph.nodes: + n: Node + deps[n] = len(n.users) + # A forward tensor who is marked `save` but is also + # an input to `Phase.FORWARD` should be saved during forward. + # If the tensor is a placeholder, then it belongs to `fwd_mem_in`. + # Any `fwd_mem_in` should be kept in memory even this function + # is checkpointed. + # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint + # the node, `fwd_mem_tmp` can be freed. + if is_phase(n, Phase.PLACEHOLDER): + graph_info.fwd_in += n.meta['saved_tensor'] + if is_phase(n, Phase.FORWARD): + graph_info.fwd_tmp += n.meta['saved_tensor'] + elif is_phase(n, Phase.BACKWARD): + if len(n.users): + graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps)) + else: + # TODO: some of the bwd_mem_out might be model parameters. + # basically a backward node without user is a `grad_out` node + graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor']) + for input_n in n.all_input_nodes: + if input_n in deps: + deps[input_n] -= 1 + if deps[input_n] <= 0: + deps[input_n] = float('-inf') + return graph_info diff --git a/colossalai/fx/profiler/experimental/__init__.py b/colossalai/fx/profiler/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a5387981e1921497d27c2acee9ebf4e310c3add7 --- /dev/null +++ b/colossalai/fx/profiler/experimental/__init__.py @@ -0,0 +1,5 @@ +from .profiler import profile_function, profile_method, profile_module +from .profiler_function import * +from .profiler_module import * +from .registry import meta_profiler_function, meta_profiler_module +from .shard_utils import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp diff --git a/colossalai/fx/profiler/experimental/constants.py b/colossalai/fx/profiler/experimental/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..57ff3fd91299b5bb8938125bf2d3243c9a9c4c2b --- /dev/null +++ b/colossalai/fx/profiler/experimental/constants.py @@ -0,0 +1,44 @@ +from operator import add, floordiv, getitem, mul, neg, pos, setitem, sub + +import torch + +__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD'] + +# TODO fill out the inplace ops +INPLACE_OPS = [ + add, + sub, + mul, + floordiv, + neg, + pos, + getitem, + setitem, + getattr, + torch.Tensor.cpu, +] + +# TODO: list all call_methods that are inplace here +INPLACE_METHOD = [ + 'transpose', + 'permute', + # TODO: reshape may return a copy of the data if the data is not contiguous + 'reshape', + 'dim', + 'flatten', + 'size', + 'view', + 'unsqueeze', + 'to', + 'type', + 'flatten', +] + +# TODO: list all call_methods that are not inplace here +NON_INPLACE_METHOD = [ + 'chunk', + 'contiguous', + 'expand', + 'mean', + 'split', +] diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..5c545260e72b723bfa54beacfb20def3e758413f --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -0,0 +1,172 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, Tuple + +import torch +from torch.fx.node import Argument, Target + +from ..._compatibility import compatibility +from ..memory_utils import activation_size +from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD +from .registry import meta_profiler_function, meta_profiler_module + +__all__ = ['profile_function', 'profile_module', 'profile_method'] + + +# this is for compatibility use +@compatibility(is_backward_compatible=True) +@dataclass +class GraphInfo: + """ + GraphInfo is a dataclass for MetaInfo, which measures + the execution memory cost and FLOPs with `MetaTensor`. + The dataflow analysis is conducted on a single node of the FX graph. + ============================================================================ + ------------------------------- + | Node | + [fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] is marks the memory for `grad_out` + placeholders saved for | | \__________ | | + backward. | | \ | | + | [fwd_tmp] ------> [bwd_tmp] | <----- + | | \_________ | | [bwd_tmp] marks the peak memory + | / \ \ | | in backward pass. + [x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <----- + in [fwd_tmp] because | | | \_____ | | + it is not saved for | | | \ | | + backward. ------------------------------- + ============================================================================ + Attributes: + fwd_flop (int): The forward FLOPs of a certain node + bwd_flop (int): The backward FLOPs of a certain node. + fwd_mem_in (int): See the above illustration. + fwd_mem_tmp (int): See the above illustration. + bwd_mem_tmp (int): See the above illustration. + bwd_mem_out (int): See the above illustration. + """ + fwd_flop: int = 0 + bwd_flop: int = 0 + fwd_mem_in: int = 0 + fwd_mem_tmp: int = 0 + bwd_mem_tmp: int = 0 + bwd_mem_out: int = 0 + + +CALL_FUNCTION_MSG = \ +""" +Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n +from colossalai.fx.profiler.experimental import meta_profiler_function +@meta_profiler_function.register(YOUR_FUNCTION) +def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]: + flops = ... + macs = ... + return flops, macs +""" +CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}' +CALL_MODULE_MSG = \ +""" +Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n +from colossalai.fx.profiler.experimental import meta_profiler_module +@meta_profiler_module.register(YOUR_MODULE) +def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]: + flops = ... + macs = ... + return flops, macs +""" + + +@compatibility(is_backward_compatible=True) +def profile_function(target: 'Target') -> Callable: + """ + Wrap a `call_function` node or `torch.nn.functional` in order to + record the memory cost and FLOPs of the execution. + Unfortunately, backward memory cost and FLOPs are estimated results. + + Warnings: + You may only use tensors with `device=meta` for this wrapped function. + Only original `torch.nn.functional` are available. + + Examples: + >>> input = torch.rand(100, 100, 100, 100, device='meta') + >>> func = torch.nn.functional.relu + >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False) + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + assert meta_profiler_function.has(target) or meta_profiler_function.has( + target.__name__), CALL_FUNCTION_MSG.format(target) + + fwd_tmp = 0 + fwd_out = 0 + out = func(*args, **kwargs) + if target not in INPLACE_OPS and not kwargs.get('inplace', False): + fwd_out = activation_size(out) + if meta_profiler_function.has(target): + profiler = meta_profiler_function.get(target) + else: + profiler = meta_profiler_function.get(target.__name__) + fwd_flop, _ = profiler(*args, **kwargs) + return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + + f.__name__ = target.__name__ + func = target + return f + + +@compatibility(is_backward_compatible=True) +def profile_method(target: 'Target') -> Callable: + """ + Wrap a `call_method` node + record the memory cost and FLOPs of the execution. + + Warnings: + This is not fully implemented and you may follow the error message to debug. + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + + # execute the method and return the result + assert isinstance(target, str), f'{target} instance is not str.' + + out = getattr(self_obj, target)(*args_tail, **kwargs) + assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format( + target, INPLACE_METHOD, NON_INPLACE_METHOD) + # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs. + fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out) + fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out) + return out, GraphInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + + return f + + +@compatibility(is_backward_compatible=True) +def profile_module(module: torch.nn.Module) -> Callable: + """ + Wrap a `call_module` node or `torch.nn` in order to + record the memory cost and FLOPs of the execution. + + Warnings: + You may only use tensors with `device=meta` for this wrapped function. + Only original `torch.nn` are available. + + Example: + >>> input = torch.rand(4, 3, 224, 224, device='meta') + >>> mod = torch.nn.Conv2d(3, 128, 3) + >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input) + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module)) + + fwd_tmp = 0 + fwd_out = 0 + out = func(*args, **kwargs) + if getattr(module, 'inplace', False): + fwd_out = activation_size(out) + profiler = meta_profiler_module.get(type(module)) + fwd_flop, _ = profiler(module, *args, **kwargs) + return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + + f.__name__ = module.__class__.__name__ + func = module.forward + return f diff --git a/colossalai/fx/profiler/experimental/profiler_function/__init__.py b/colossalai/fx/profiler/experimental/profiler_function/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf77edba859ecc568a5010287b8797fc31bb6701 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/__init__.py @@ -0,0 +1,8 @@ +from .activation_function import * +from .arithmetic import * +from .embedding import * +from .linear import * +from .normalization import * +from .pooling import * +from .python_ops import * +from .torch_ops import * diff --git a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py new file mode 100644 index 0000000000000000000000000000000000000000..a43aef063e197de12c23fc5a81fb13e8183eaae9 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py @@ -0,0 +1,33 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_function + +# TODO: different activation has different FLOPs count, currently unused. +_multiplier = { + torch.nn.functional.relu: 1, + torch.nn.functional.prelu: 4, + torch.nn.functional.sigmoid: 4, + torch.nn.functional.tanh: 5, + torch.nn.functional.leaky_relu: 3, + torch.nn.functional.elu: 4, + torch.nn.functional.relu6: 2, + torch.nn.functional.gelu: 9, + torch.nn.functional.hardswish: 5, + torch.nn.functional.hardsigmoid: 4, +} + + +@meta_profiler_function.register(torch.nn.functional.leaky_relu) +@meta_profiler_function.register(torch.nn.functional.elu) +@meta_profiler_function.register(torch.nn.functional.gelu) +@meta_profiler_function.register(torch.nn.functional.relu6) +@meta_profiler_function.register(torch.nn.functional.prelu) +@meta_profiler_function.register(torch.nn.functional.relu) +@meta_profiler_function.register(torch.nn.functional.sigmoid) +@meta_profiler_function.register(torch.nn.functional.tanh) +@meta_profiler_function.register(torch.nn.functional.hardswish) +@meta_profiler_function.register(torch.nn.functional.hardsigmoid) +def torch_nn_func_non_linear_act(input: torch.Tensor, inplace: bool = False) -> Tuple[int, int]: + flops = input.numel() + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py new file mode 100644 index 0000000000000000000000000000000000000000..8d1c8a8c6877fc12a7ad47b7b0103c309b8ed597 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. + +# Licensed under the MIT License. +import operator +from functools import reduce +from typing import Any, Optional, Tuple, Union + +import torch + +from ..registry import meta_profiler_function + + +def _elementwise_flops_compute(input, other): + # copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L763 + if not torch.is_tensor(input): + if torch.is_tensor(other): + return reduce(operator.mul, other.shape), 0 + else: + return 1, 0 + elif not torch.is_tensor(other): + return reduce(operator.mul, input.shape), 0 + else: + dim_input = len(input.shape) + dim_other = len(other.shape) + max_dim = max(dim_input, dim_other) + + final_shape = [] + for i in range(max_dim): + in_i = input.shape[i] if i < dim_input else 1 + ot_i = other.shape[i] if i < dim_other else 1 + if in_i > ot_i: + final_shape.append(in_i) + else: + final_shape.append(ot_i) + flops = reduce(operator.mul, final_shape) + return flops, 0 + + +@meta_profiler_function.register(torch.add) +@meta_profiler_function.register(torch.eq) +@meta_profiler_function.register(torch.sub) +@meta_profiler_function.register(torch.mul) +@meta_profiler_function.register(torch.floor_divide) +@meta_profiler_function.register('add') # for built-in op + +@meta_profiler_function.register('iadd') # for built-in op += +@meta_profiler_function.register('eq') # for built-in op = +@meta_profiler_function.register('sub') # for built-in op - +@meta_profiler_function.register('isub') # for built-in op -= +@meta_profiler_function.register('mul') # for built-in op * +@meta_profiler_function.register('imul') # for built-in op *= +@meta_profiler_function.register('floordiv') # for built-in op // +@meta_profiler_function.register('ifloordiv') # for built-in op //= +def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + return _elementwise_flops_compute(input, other) + + +@meta_profiler_function.register(torch.abs) +def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + flops = input.numel() + macs = 0 + return flops, macs + + +@meta_profiler_function.register(torch.matmul) +@meta_profiler_function.register('matmul') # for built-in op @ +@meta_profiler_function.register(torch.Tensor.matmul) +def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + macs = reduce(operator.mul, input.shape) * other.shape[-1] + flops = 2 * macs + return flops, macs + + +@meta_profiler_function.register(torch.bmm) +def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + macs = reduce(operator.mul, input.shape) * other.shape[-1] + flops = 2 * macs + return flops, macs + + +@meta_profiler_function.register(torch.var_mean) +def torch_var_mean(input: torch.Tensor, + dim: Union[int, Tuple[int, ...]], + unbiased: Optional[bool] = True, + keepdim: Optional[bool] = False, + *, + out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + assert out is None, 'saving to out is not supported yet' + flops = input.numel() * 3 + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/embedding.py b/colossalai/fx/profiler/experimental/profiler_function/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e43d781b8b64ab78cf3299daba3df1d17a5420 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/embedding.py @@ -0,0 +1,19 @@ +import torch +from typing import Optional +from ..registry import meta_profiler_function + + +@meta_profiler_function.register(torch.nn.functional.embedding) +def torch_nn_functional_embedding( + input: torch.Tensor, + weight: torch.Tensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> torch.Tensor: + # F.embedding is a dictionary lookup, so technically it has 0 FLOPs. (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6) + flops = 0 + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/linear.py b/colossalai/fx/profiler/experimental/profiler_function/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..01fe4c87137083db2458c560a88cc6faa0af377e --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/linear.py @@ -0,0 +1,13 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_function + + +@meta_profiler_function.register(torch.nn.functional.linear) +def torch_nn_linear(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> Tuple[int, int]: + out_features = weight.shape[0] + macs = torch.numel(input) * out_features + flops = 2 * macs + if bias is not None: + flops += bias.numel() + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/normalization.py b/colossalai/fx/profiler/experimental/profiler_function/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ea508d70f80f33bbc8ae354e9743a4939d5e8c --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/normalization.py @@ -0,0 +1,66 @@ +from typing import List, Optional, Tuple +import torch +from ..registry import meta_profiler_function + + +@meta_profiler_function.register(torch.nn.functional.instance_norm) +def torch_nn_func_instancenorm( + input: torch.Tensor, + running_mean: Optional[torch.Tensor] = None, + running_var: Optional[torch.Tensor] = None, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + use_input_stats: bool = True, + momentum: float = 0.1, + eps: float = 1e-5, +): + has_affine = weight is not None + flops = input.numel() * (5 if has_affine else 4) + macs = 0 + return flops, macs + + +@meta_profiler_function.register(torch.nn.functional.group_norm) +def torch_nn_func_groupnorm(input: torch.Tensor, + num_groups: int, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-5) -> Tuple[int, int]: + has_affine = weight is not None + flops = input.numel() * (5 if has_affine else 4) + macs = 0 + return flops, macs + + +@meta_profiler_function.register(torch.nn.functional.layer_norm) +def torch_nn_func_layernorm( + input: torch.Tensor, + normalized_shape: List[int], + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-5, +) -> Tuple[int, int]: + has_affine = weight is not None + flops = input.numel() * (5 if has_affine else 4) + macs = 0 + return flops, macs + + +@meta_profiler_function.register(torch.nn.functional.batch_norm) +def torch_nn_func_batchnorm( + input: torch.Tensor, + running_mean: Optional[torch.Tensor], + running_var: Optional[torch.Tensor], + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + training: bool = False, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tuple[int, int]: + has_affine = weight is not None + if training: + flops = input.numel() * (2 if has_affine else 1) + else: + flops = input.numel() * (5 if has_affine else 4) + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/pooling.py b/colossalai/fx/profiler/experimental/profiler_function/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..a639f5ee83c1f4d2b75a3c120ea6ae3884fc422f --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/pooling.py @@ -0,0 +1,22 @@ +from typing import Tuple, Union +import torch +from ..registry import meta_profiler_function + + +@meta_profiler_function.register(torch.nn.functional.avg_pool1d) +@meta_profiler_function.register(torch.nn.functional.avg_pool2d) +@meta_profiler_function.register(torch.nn.functional.avg_pool3d) +@meta_profiler_function.register(torch.nn.functional.max_pool1d) +@meta_profiler_function.register(torch.nn.functional.max_pool2d) +@meta_profiler_function.register(torch.nn.functional.max_pool3d) +@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool1d) +@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool2d) +@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool3d) +@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool1d) +@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool2d) +@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool3d) +def torch_nn_func_pooling(input: torch.Tensor, *args, **kwargs) -> Tuple[int, int]: + # all pooling could be considered as going over each input element only once (https://stackoverflow.com/a/67301217) + flops = input.numel() + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..1e8561206ba0e7a874b202a31dd13a040533d1db --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py @@ -0,0 +1,18 @@ +import operator +from typing import Any, Tuple +import torch +from ..registry import meta_profiler_function + + +@meta_profiler_function.register(operator.getitem) +def operator_getitem(a: Any, b: Any) -> Tuple[int, int]: + flops = 0 + macs = 0 + return flops, macs + + +@meta_profiler_function.register(getattr) +def python_getattr(a: Any, b: Any) -> Tuple[int, int]: + flops = 0 + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..abdd7ad565ba237d7d6eab9e3c9b77d7afb10abf --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py @@ -0,0 +1,60 @@ +from functools import reduce +import operator +from typing import Any, Optional, Tuple +import torch +from ..registry import meta_profiler_function + + +@meta_profiler_function.register(torch.arange) +@meta_profiler_function.register(torch.finfo) +@meta_profiler_function.register(torch.permute) +@meta_profiler_function.register(torch.Tensor.permute) +@meta_profiler_function.register(torch.Tensor.repeat) +@meta_profiler_function.register(torch.index_select) +@meta_profiler_function.register(torch.Tensor.index_select) +@meta_profiler_function.register(torch.squeeze) +@meta_profiler_function.register(torch.Tensor.squeeze) +@meta_profiler_function.register(torch.unsqueeze) +@meta_profiler_function.register(torch.Tensor.unsqueeze) +@meta_profiler_function.register(torch.cat) +@meta_profiler_function.register(torch.concat) +@meta_profiler_function.register(torch.repeat_interleave) +@meta_profiler_function.register(torch.Tensor.repeat_interleave) +@meta_profiler_function.register(torch.flatten) +@meta_profiler_function.register(torch.Tensor.flatten) +@meta_profiler_function.register(torch.roll) +@meta_profiler_function.register(torch.full) +@meta_profiler_function.register(torch.Tensor.cpu) +@meta_profiler_function.register(torch.Tensor.cuda) +@meta_profiler_function.register(torch._assert) +def torch_zero_flops_op(*args, **kwargs) -> Tuple[int, int]: + flops = 0 + macs = 0 + return flops, macs + + +@meta_profiler_function.register(torch.where) +def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]: + # torch.where returns the broadcasted tensor of condition, x, and y, + # so hack it by using addition + flops = condition.numel() + macs = 0 + return flops, macs + + +@meta_profiler_function.register(torch.max) +def torch_max(input: torch.Tensor, + dim: int = None, + keepdim: bool = False, + *, + out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + macs = 0 + assert out is None, 'assigning value to out is not supported yet' + if dim is not None: + shape = list(input.shape) + shape.pop(int(dim)) + flops = reduce(operator.mul, shape), macs + return flops, macs + else: + flops = input.numel() + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/__init__.py b/colossalai/fx/profiler/experimental/profiler_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4fe646f3695e63d285b9a32530cd70f10187f34 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/__init__.py @@ -0,0 +1,10 @@ +from .activation_function import * +from .attention import * +from .convolution import * +from .dropout import * +from .embedding import * +from .linear import * +from .normalization import * +from .pooling import * +from .rnn import * +from .torch_op import * diff --git a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py new file mode 100644 index 0000000000000000000000000000000000000000..2ebf514ad2699cc4e71741b9c3e143cedcb63041 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py @@ -0,0 +1,33 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_module + +# TODO: different activation has different FLOPs count, currently unused. +_multiplier = { + torch.nn.ReLU: 1, + torch.nn.PReLU: 4, + torch.nn.Sigmoid: 4, + torch.nn.Tanh: 5, + torch.nn.LeakyReLU: 3, + torch.nn.ELU: 4, + torch.nn.ReLU6: 2, + torch.nn.GELU: 9, + torch.nn.Hardswish: 5, + torch.nn.Hardsigmoid: 4, +} + + +@meta_profiler_module.register(torch.nn.ELU) +@meta_profiler_module.register(torch.nn.LeakyReLU) +@meta_profiler_module.register(torch.nn.ReLU) +@meta_profiler_module.register(torch.nn.GELU) +@meta_profiler_module.register(torch.nn.Sigmoid) +@meta_profiler_module.register(torch.nn.Tanh) +@meta_profiler_module.register(torch.nn.ReLU6) +@meta_profiler_module.register(torch.nn.PReLU) +@meta_profiler_module.register(torch.nn.Hardswish) +@meta_profiler_module.register(torch.nn.Hardsigmoid) +def torch_nn_non_linear_act(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]: + flops = input.numel() + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/attention.py b/colossalai/fx/profiler/experimental/profiler_module/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8daf74b232bf91d41933a2184e7b0c30d516d51a --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/attention.py @@ -0,0 +1,81 @@ +from typing import Optional, Tuple +import torch +from ..registry import meta_profiler_module + + +# TODO: This is hard to compute memory cost +@meta_profiler_module.register(torch.nn.MultiheadAttention) +def torch_nn_msa(self: torch.nn.MultiheadAttention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[torch.Tensor] = None, + average_attn_weights: bool = True) -> Tuple[int, int]: + if getattr(self, 'batch_first', False): + batch_size = query.shape[0] + len_idx = 1 + else: + batch_size = query.shape[1] + len_idx = 0 + dim_idx = 2 + + qdim = query.shape[dim_idx] + kdim = key.shape[dim_idx] + vdim = value.shape[dim_idx] + + qlen = query.shape[len_idx] + klen = key.shape[len_idx] + vlen = value.shape[len_idx] + + num_heads = self.num_heads + assert qdim == self.embed_dim + + if self.kdim is None: + assert kdim == qdim + if self.vdim is None: + assert vdim == qdim + + flops = 0 + macs = 0 + + # Q scaling + flops += qlen * qdim + + # Initial projections + flops += 2 * ((qlen * qdim * qdim) # QW + + (klen * kdim * kdim) # KW + + (vlen * vdim * vdim) # VW + ) + + macs += ((qlen * qdim * qdim) # QW + + (klen * kdim * kdim) # KW + + (vlen * vdim * vdim) # VW + ) + + if self.in_proj_bias is not None: + flops += (qlen + klen + vlen) * qdim + + # attention heads: scale, matmul, softmax, matmul + qk_head_dim = qdim // num_heads + v_head_dim = vdim // num_heads + + head_flops = ( + 2 * (qlen * klen * qk_head_dim) # QK^T + + (qlen * klen) # softmax + + 2 * (qlen * klen * v_head_dim) # AV + ) + head_macs = ((qlen * klen * qk_head_dim) # QK^T + + 2 * (qlen * klen * v_head_dim) # AV + ) + + flops += num_heads * head_flops + macs += num_heads * head_flops + + # final projection, bias is always enabled + flops += qlen * vdim * (vdim + 1) + + flops *= batch_size + macs *= batch_size + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/convolution.py b/colossalai/fx/profiler/experimental/profiler_module/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..a4c15b91e611d5d1398eeb38ec5c106c2652749b --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/convolution.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft Corporation. + +# Licensed under the MIT License. +import math +import operator +from functools import reduce +from typing import Tuple + +import torch + +from ..registry import meta_profiler_module + + +@meta_profiler_module.register(torch.nn.Conv1d) +def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, int]: + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + c_in, l_in = input.shape[-2:] + c_out = self.out_channels + l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + result_shape = input.shape[:-2] + ( + c_out, + l_out, + ) + macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups + num_elem = reduce(operator.mul, result_shape) + macs = macs_per_elem * num_elem + flops = 2 * macs + if self.bias is not None: + flops += num_elem + return flops, macs + + +@meta_profiler_module.register(torch.nn.Conv2d) +def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, int]: + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + c_in, h_in, w_in = input.shape[-3:] + c_out = self.out_channels + h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * + (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + result_shape = input.shape[:-3] + ( + c_out, + h_out, + w_out, + ) + macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups + num_elem = reduce(operator.mul, result_shape) + macs = macs_per_elem * num_elem + flops = 2 * macs + if self.bias is not None: + flops += num_elem + return flops, macs + + +@meta_profiler_module.register(torch.nn.Conv3d) +def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, int]: + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html + c_in, d_in, h_in, w_in = input.shape[-4:] + c_out = self.out_channels + d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * + (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * + (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups + num_elem = reduce(operator.mul, result_shape) + macs = macs_per_elem * num_elem + flops = 2 * macs + if self.bias is not None: + flops += num_elem + return flops, macs + + +@meta_profiler_module.register(torch.nn.ConvTranspose1d) +def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor) -> Tuple[int, int]: + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html + c_in, l_in = input.shape[-2:] + c_out = self.out_channels + l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * + (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + result_shape = input.shape[:-2] + ( + c_out, + l_out, + ) + macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups + num_elem = reduce( + operator.mul, input.shape + ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604 + macs = macs_per_elem * num_elem + flops = 2 * macs + if self.bias is not None: + flops += reduce(operator.mul, result_shape) + return flops, macs + + +@meta_profiler_module.register(torch.nn.ConvTranspose2d) +def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor) -> Tuple[int, int]: + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + c_in, h_in, w_in = input.shape[-3:] + c_out = self.out_channels + h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * + (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * + (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + result_shape = input.shape[:-3] + ( + c_out, + h_out, + w_out, + ) + macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups + num_elem = reduce(operator.mul, input.shape) + macs = macs_per_elem * num_elem + flops = 2 * macs + if self.bias is not None: + flops += reduce(operator.mul, result_shape) + return flops, macs + + +@meta_profiler_module.register(torch.nn.ConvTranspose3d) +def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor) -> Tuple[int, int]: + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html + c_in, d_in, h_in, w_in = input.shape[-4:] + c_out = self.out_channels + d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * + (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * + (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] * + (self.kernel_size[2] - 1) + self.output_padding[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups + num_elem = reduce(operator.mul, input.shape) + macs = macs_per_elem * num_elem + flops = 2 * macs + if self.bias is not None: + flops += reduce(operator.mul, result_shape) + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/dropout.py b/colossalai/fx/profiler/experimental/profiler_module/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..417e0ed468637a5ce049ffa8137a73e5b266c971 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/dropout.py @@ -0,0 +1,11 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_module + + +@meta_profiler_module.register(torch.nn.Dropout) +def torch_nn_dropout(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]: + # nn.Embedding is a dictionary lookup, so technically it has 0 FLOPs. (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6) + flops = 0 + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/embedding.py b/colossalai/fx/profiler/experimental/profiler_module/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ade5d3ad93845eec84ca543f0a9afaf5b7b7fd --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/embedding.py @@ -0,0 +1,13 @@ +from typing import Tuple + +import torch + +from ..registry import meta_profiler_module + + +@meta_profiler_module.register(torch.nn.Embedding) +def torch_nn_embedding(self: torch.nn.Embedding, input: torch.Tensor) -> Tuple[int, int]: + # nn.Embedding is a dictionary lookup, so technically it has 0 FLOPs. (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6) + flops = 0 + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/linear.py b/colossalai/fx/profiler/experimental/profiler_module/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..e1ffb6f244d2ed7d5764339d61fdb46f71ae59a2 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/linear.py @@ -0,0 +1,14 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_module + + +@meta_profiler_module.register(torch.nn.Linear) +@meta_profiler_module.register(torch.nn.modules.linear.NonDynamicallyQuantizableLinear) +def torch_nn_linear(self: torch.nn.Linear, input: torch.Tensor) -> Tuple[int, int]: + out_features = self.weight.shape[0] + macs = input.numel() * out_features + flops = 2 * macs + if self.bias is not None: + flops += self.bias.numel() + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/normalization.py b/colossalai/fx/profiler/experimental/profiler_module/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..49e5e6fa5384b07412abe7ecc947d7963e88bd1a --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/normalization.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. + +# Licensed under the MIT License. +from typing import Tuple, Union + +import torch + +from ..registry import meta_profiler_module + + +@meta_profiler_module.register(torch.nn.InstanceNorm1d) +@meta_profiler_module.register(torch.nn.InstanceNorm2d) +@meta_profiler_module.register(torch.nn.InstanceNorm3d) +@meta_profiler_module.register(torch.nn.LayerNorm) +@meta_profiler_module.register(torch.nn.GroupNorm) +@meta_profiler_module.register(torch.nn.BatchNorm1d) +@meta_profiler_module.register(torch.nn.BatchNorm2d) +@meta_profiler_module.register(torch.nn.BatchNorm3d) +def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]: + # adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615 + has_affine = self.weight is not None + if self.training: + flops = input.numel() * (2 if has_affine else 1) + else: + flops = input.numel() * (5 if has_affine else 4) + macs = 0 + return flops, macs + + +try: + import apex + meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize) + meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) + meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) + meta_profiler_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize) +except (ImportError, AttributeError): + pass diff --git a/colossalai/fx/profiler/experimental/profiler_module/pooling.py b/colossalai/fx/profiler/experimental/profiler_module/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..e429ac3eea28055f42af2ea8f84663a5a6fd2a83 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/pooling.py @@ -0,0 +1,22 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_module + + +@meta_profiler_module.register(torch.nn.AvgPool1d) +@meta_profiler_module.register(torch.nn.AvgPool2d) +@meta_profiler_module.register(torch.nn.AvgPool3d) +@meta_profiler_module.register(torch.nn.MaxPool1d) +@meta_profiler_module.register(torch.nn.MaxPool2d) +@meta_profiler_module.register(torch.nn.MaxPool3d) +@meta_profiler_module.register(torch.nn.AdaptiveAvgPool1d) +@meta_profiler_module.register(torch.nn.AdaptiveMaxPool1d) +@meta_profiler_module.register(torch.nn.AdaptiveAvgPool2d) +@meta_profiler_module.register(torch.nn.AdaptiveMaxPool2d) +@meta_profiler_module.register(torch.nn.AdaptiveAvgPool3d) +@meta_profiler_module.register(torch.nn.AdaptiveMaxPool3d) +def torch_nn_pooling(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]: + # all pooling could be considered as going over each input element only once (https://stackoverflow.com/a/67301217) + flops = input.numel() + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/rnn.py b/colossalai/fx/profiler/experimental/profiler_module/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..6e733d6da9156db13b2bac35af63a52ad89ad5a3 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/rnn.py @@ -0,0 +1,75 @@ +from functools import reduce +import operator +import torch +from ..registry import meta_profiler_module +from typing import Optional, Tuple, Union + + +def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, + w_hh: torch.Tensor) -> Tuple[int, int]: + # copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py + + # matrix matrix mult ih state and internal state + macs += reduce(operator.mul, w_ih.shape) + flops += 2 * reduce(operator.mul, w_ih.shape) + # matrix matrix mult hh state and internal state + macs += reduce(operator.mul, w_hh.shape) + flops += 2 * reduce(operator.mul, w_hh.shape) + if isinstance(module, (torch.nn.RNN, torch.nn.RNNCell)): + # add both operations + flops += module.hidden_size + elif isinstance(module, (torch.nn.GRU, torch.nn.GRUCell)): + # hadamard of r + flops += module.hidden_size + # adding operations from both states + flops += module.hidden_size * 3 + # last two hadamard product and add + flops += module.hidden_size * 3 + elif isinstance(module, (torch.nn.LSTM, torch.nn.LSTMCell)): + # adding operations from both states + flops += module.hidden_size * 4 + # two hadamard product and add for C state + flops += module.hidden_size * 3 + # final hadamard + flops += module.hidden_size * 3 + return flops, macs + + +@meta_profiler_module.register(torch.nn.LSTM) +@meta_profiler_module.register(torch.nn.GRU) +@meta_profiler_module.register(torch.nn.RNN) +def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]: + flops = 0 + macs = 0 + for i in range(self.num_layers): + w_ih = self.__getattr__('weight_ih_l' + str(i)) + w_hh = self.__getattr__('weight_hh_l' + str(i)) + flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh) + if self.bias: + b_ih = self.__getattr__('bias_ih_l' + str(i)) + b_hh = self.__getattr__('bias_hh_l' + str(i)) + flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh) + flops *= reduce(operator.mul, input.shape[:2]) + macs *= reduce(operator.mul, input.shape[:2]) + if self.bidirectional: + flops *= 2 + macs *= 2 + return flops, macs + + +@meta_profiler_module.register(torch.nn.LSTMCell) +@meta_profiler_module.register(torch.nn.GRUCell) +@meta_profiler_module.register(torch.nn.RNNCell) +def torch_nn_rnn(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]: + flops = 0 + macs = 0 + w_ih = self.__getattr__('weight_ih_l') + w_hh = self.__getattr__('weight_hh_l') + flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh) + if self.bias: + b_ih = self.__getattr__('bias_ih_l') + b_hh = self.__getattr__('bias_hh_l') + flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh) + flops *= input.shape[0] + macs *= input.shape[0] + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py new file mode 100644 index 0000000000000000000000000000000000000000..d3aed874eb10af76dc94e21b23c566178afe6264 --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py @@ -0,0 +1,11 @@ +import operator +import torch +from ..registry import meta_profiler_module +from typing import Optional, Tuple, Union + + +@meta_profiler_module.register(torch.nn.Flatten) +def torch_nn_flatten(self: torch.nn.Flatten, input: torch.Tensor) -> Tuple[int, int]: + flops = 0 + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/experimental/registry.py b/colossalai/fx/profiler/experimental/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..7d73bce321e43d7c1284bf4d78dbac2bc7c4abfc --- /dev/null +++ b/colossalai/fx/profiler/experimental/registry.py @@ -0,0 +1,25 @@ +class ProfilerRegistry: + + def __init__(self, name): + self.name = name + self.store = {} + + def register(self, source): + + def wrapper(func): + self.store[source] = func + return func + + return wrapper + + def get(self, source): + assert source in self.store + target = self.store[source] + return target + + def has(self, source): + return source in self.store + + +meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile') +meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile') diff --git a/colossalai/fx/profiler/experimental/shard_utils.py b/colossalai/fx/profiler/experimental/shard_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1e53ed0bf8ec657d8916d49c8c4b97f1d996010a --- /dev/null +++ b/colossalai/fx/profiler/experimental/shard_utils.py @@ -0,0 +1,48 @@ +# for PyTorch 1.11 compatibility uses +from typing import Dict, List, Tuple, Union + +import torch +from torch.fx import GraphModule, Node + +from ..._compatibility import compatibility + +__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"] + + +@compatibility(is_backward_compatible=True) +def calculate_fwd_in(n: Node) -> bool: + """A helper function to calculate `fwd_in` + + Args: + n (Node): a node from the graph + + Returns: + save_fwd_in (bool): the result of `save_fwd_in` + """ + return n.meta['save_fwd_in'] + + +@compatibility(is_backward_compatible=True) +def calculate_fwd_tmp(n: Node) -> int: + """A helper function to calculate `fwd_tmp` + + Args: + n (Node): a node from the graph + + Returns: + fwd_tmp (int): the result of `fwd_tmp` + """ + return n.meta["fwd_mem_tmp"] + + +@compatibility(is_backward_compatible=True) +def calculate_fwd_out(n: Node) -> int: + """A helper function to calculate `fwd_out` + + Args: + n (Node): a node from the graph + + Returns: + fwd_out (int): the result of `fwd_out` + """ + return n.meta['fwd_mem_out'] diff --git a/colossalai/fx/profiler/memory_utils.py b/colossalai/fx/profiler/memory_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6ccbcb01cdc14045fbdd4906f1fc6f2a5ad728db --- /dev/null +++ b/colossalai/fx/profiler/memory_utils.py @@ -0,0 +1,71 @@ +from typing import Dict, List, Tuple, Union + +import torch +from torch.fx import GraphModule, Node + +from .._compatibility import compatibility, is_compatible_with_meta + +__all__ = ['activation_size', 'parameter_size', 'is_inplace'] + + +@compatibility(is_backward_compatible=True) +def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int: + """Calculate activation size of a node. + + Args: + activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`. + + Returns: + int: The activation size, unit is byte. + """ + act_size = 0 + if isinstance(out, torch.Tensor): + if out.is_quantized: + act_size += out.numel() * torch._empty_affine_quantized([], dtype=out.dtype).element_size() + else: + act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size() + elif isinstance(out, dict): + value_list = [v for _, v in out.items()] + act_size += activation_size(value_list) + elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set): + for element in out: + act_size += activation_size(element) + return act_size + + +@compatibility(is_backward_compatible=True) +def parameter_size(mod: torch.nn.Module) -> int: + """Calculate parameter size of a node. + + Args: + mod (torch.nn.Module): The target `torch.nn.Module`. + + Returns: + int: The parameter size, unit is byte. + """ + param_size = 0 + for param in mod.parameters(): + param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size() + return param_size + + +def is_inplace(n: Node): + """Get the inplace argument from torch.fx.Node + + Args: + node (Node): torch.fx.Node + + Returns: + bool: indicates whether this op is inplace + """ + inplace = False + if n.op == "call_function": + inplace = n.kwargs.get("inplace", False) + if is_compatible_with_meta(): + from .constants import ALIAS_ATEN + if n.target in ALIAS_ATEN: + inplace = True + elif n.op == "call_module": + inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) + + return inplace diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py new file mode 100644 index 0000000000000000000000000000000000000000..ba090a2ec51bd7d1d83a6dd5d75c877c0708577f --- /dev/null +++ b/colossalai/fx/profiler/opcount.py @@ -0,0 +1,374 @@ +# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py +# ideas from https://pastebin.com/AkvAyJBw + +import operator +from functools import partial, reduce +from numbers import Number +from typing import Any, Callable, List + +import torch +from packaging import version + +aten = torch.ops.aten + + +def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for matmul. + """ + # Inputs should be a list of length 2. + # Inputs contains the shapes of two matrices. + input_shapes = [v.shape for v in inputs] + assert len(input_shapes) == 2, input_shapes + + # There are three cases: 1) gemm, 2) gemv, 3) dot + if all(len(shape) == 2 for shape in input_shapes): + # gemm + assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes + elif all(len(shape) == 1 for shape in input_shapes): + # dot + assert input_shapes[0][0] == input_shapes[1][0], input_shapes + + # expand shape + input_shapes[0] = torch.Size([1, input_shapes[0][0]]) + input_shapes[1] = torch.Size([input_shapes[1][0], 1]) + else: + # gemv + if len(input_shapes[0]) == 1: + assert input_shapes[0][0] == input_shapes[1][-2], input_shapes + input_shapes.reverse() + else: + assert input_shapes[1][0] == input_shapes[0][-1], input_shapes + + # expand the shape of the vector to [batch size, 1] + input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1]) + flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1] + return flops + + +def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for fully connected layers. + """ + # Count flop for nn.Linear + # inputs is a list of length 3. + input_shapes = [v.shape for v in inputs[1:3]] + # input_shapes[0]: [batch size, input feature dimension] + # input_shapes[1]: [input feature dimension, output feature dimension] + assert len(input_shapes[0]) == 2, input_shapes[0] + assert len(input_shapes[1]) == 2, input_shapes[1] + batch_size, input_dim = input_shapes[0] + output_dim = input_shapes[1][1] + flops = batch_size * input_dim * output_dim + return flops + + +def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for the aten::linear operator. + """ + # Inputs is a list of length 3; unlike aten::addmm, it is the first + # two elements that are relevant. + input_shapes = [v.shape for v in inputs[0:2]] + # input_shapes[0]: [dim0, dim1, ..., input_feature_dim] + # input_shapes[1]: [output_feature_dim, input_feature_dim] + assert input_shapes[0][-1] == input_shapes[1][-1] + flops = reduce(operator.mul, input_shapes[0]) * input_shapes[1][0] + return flops + + +def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for the bmm operation. + """ + # Inputs should be a list of length 2. + # Inputs contains the shapes of two tensor. + assert len(inputs) == 2, len(inputs) + input_shapes = [v.shape for v in inputs] + n, c, t = input_shapes[0] + d = input_shapes[-1][-1] + flops = n * c * t * d + return flops + + +def baddbmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for the baddbmm(batch add and batch matmul) operation. + """ + # Inputs = [input, batch1, batch2] + # out = input + batch1 x batch2 + assert len(inputs) == 3, len(inputs) + n, c, t = inputs[1].shape + d = inputs[2].shape[-1] + flops = n * c * t * d + return flops + + +def conv_flop_count( + x_shape: List[int], + w_shape: List[int], + out_shape: List[int], + transposed: bool = False, +) -> Number: + """ + Count flops for convolution. Note only multiplication is + counted. Computation for addition and bias is ignored. + Flops for a transposed convolution are calculated as + flops = (x_shape[2:] * prod(w_shape) * batch_size). + Args: + x_shape (list(int)): The input shape before convolution. + w_shape (list(int)): The filter shape. + out_shape (list(int)): The output shape after convolution. + transposed (bool): is the convolution transposed + Returns: + int: the number of flops + """ + batch_size = x_shape[0] + conv_shape = (x_shape if transposed else out_shape)[2:] + flops = batch_size * reduce(operator.mul, w_shape) * reduce(operator.mul, conv_shape) + return flops + + +def conv_flop_jit(inputs: List[Any], outputs: List[Any]): + """ + Count flops for convolution. + """ + x, w = inputs[:2] + x_shape, w_shape, out_shape = (x.shape, w.shape, outputs[0].shape) + transposed = inputs[6] + + return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed) + + +def transpose_shape(shape): + return [shape[1], shape[0]] + list(shape[2:]) + + +def conv_backward_flop_jit(inputs: List[Any], outputs: List[Any]): + grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]] + output_mask = inputs[-1] + fwd_transposed = inputs[7] + flop_count = 0 + + if output_mask[0]: + grad_input_shape = outputs[0].shape + flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed) + if output_mask[1]: + grad_weight_shape = outputs[1].shape + flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed) + + return flop_count + + +def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable: + """ + Args: + affine_arg_index: index of the affine argument in inputs + """ + + def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for norm layers. + """ + # Inputs[0] contains the shape of the input. + input_shape = inputs[input_arg_index].shape + + has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index], + 'shape') else inputs[affine_arg_index] + assert 2 <= len(input_shape) <= 5, input_shape + # 5 is just a rough estimate + flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4) + return flop + + return norm_flop_jit + + +def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number: + if training is None: + training = inputs[-3] + assert isinstance(training, bool), "Signature of aten::batch_norm has changed!" + if training: + return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore + has_affine = inputs[1].shape is not None + input_shape = reduce(operator.mul, inputs[0].shape) + return input_shape * (2 if has_affine else 1) + + +def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Callable: + """ + Count flops by + input_tensor.numel() * input_scale + output_tensor.numel() * output_scale + Args: + input_scale: scale of the input tensor (first argument) + output_scale: scale of the output tensor (first element in outputs) + """ + + def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number: + ret = 0 + if input_scale != 0: + shape = inputs[0].shape + ret += input_scale * reduce(operator.mul, shape) if shape else 0 + if output_scale != 0: + shape = outputs[0].shape + ret += output_scale * reduce(operator.mul, shape) if shape else 0 + return ret + + return elementwise_flop + + +def zero_flop_jit(*args): + """ + Count flops for zero flop layers. + """ + return 0 + + +if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse( + torch.__version__) < version.parse('2.0.0'): + flop_mapping = { + # gemm, gemv and dot + aten.mm.default: matmul_flop_jit, + aten.mv.default: matmul_flop_jit, + aten.dot.default: matmul_flop_jit, + aten.matmul.default: matmul_flop_jit, + aten.addmm.default: addmm_flop_jit, + aten.bmm.default: bmm_flop_jit, + aten.baddbmm.default: baddbmm_flop_jit, + + # convolution + aten.convolution.default: conv_flop_jit, + aten._convolution.default: conv_flop_jit, + aten.convolution_backward.default: conv_backward_flop_jit, + + # normalization + aten.native_batch_norm.default: batchnorm_flop_jit, + aten.native_batch_norm_backward.default: batchnorm_flop_jit, + aten.cudnn_batch_norm.default: batchnorm_flop_jit, + aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True), + aten.native_layer_norm.default: norm_flop_counter(2, 0), + aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), + aten.native_group_norm.default: norm_flop_counter(2, 0), + aten.native_group_norm_backward.default: norm_flop_counter(2, 0), + + # pooling + aten.avg_pool1d.default: elementwise_flop_counter(1, 0), + aten.avg_pool2d.default: elementwise_flop_counter(1, 0), + aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1), + aten.avg_pool3d.default: elementwise_flop_counter(1, 0), + aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1), + aten.max_pool1d.default: elementwise_flop_counter(1, 0), + aten.max_pool2d.default: elementwise_flop_counter(1, 0), + aten.max_pool3d.default: elementwise_flop_counter(1, 0), + aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0), + aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0), + aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1), + aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0), + aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1), + aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0), + aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1), + aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0), + aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1), + aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1), + aten.embedding.default: elementwise_flop_counter(1, 0), + aten.upsample_nearest2d.vec: elementwise_flop_counter(0, 1), + aten.upsample_nearest2d_backward.vec: elementwise_flop_counter(0, 1), + } + + elementwise_flop_aten = [ + # basic op + aten.add.Tensor, + aten.add_.Tensor, + aten.div.Tensor, + aten.div_.Tensor, + aten.div.Scalar, + aten.div_.Scalar, + aten.mul.Tensor, + aten.mul.Scalar, + aten.mul_.Tensor, + aten.neg.default, + aten.pow.Tensor_Scalar, + aten.rsub.Scalar, + aten.sum.default, + aten.sum.dim_IntList, + aten.mean.dim, + aten.sub.Tensor, + aten.sub_.Tensor, + aten.exp.default, + aten.sin.default, + aten.cos.default, + + # activation op + aten.hardswish.default, + aten.hardswish_.default, + aten.hardswish_backward.default, + aten.hardtanh.default, + aten.hardtanh_.default, + aten.hardtanh_backward.default, + aten.hardsigmoid_backward.default, + aten.hardsigmoid.default, + aten.gelu.default, + aten.gelu_backward.default, + aten.silu.default, + aten.silu_.default, + aten.silu_backward.default, + aten.sigmoid.default, + aten.sigmoid_backward.default, + aten._softmax.default, + aten._softmax_backward_data.default, + aten.relu_.default, + aten.relu.default, + aten.tanh.default, + aten.tanh_backward.default, + aten.threshold_backward.default, + + # dropout + aten.native_dropout.default, + aten.native_dropout_backward.default, + ] + for op in elementwise_flop_aten: + flop_mapping[op] = elementwise_flop_counter(1, 0) + + # TODO: this will be removed in future + zero_flop_aten = [ + aten.as_strided.default, + aten.as_strided_.default, + aten.bernoulli_.float, + aten.cat.default, + aten.clone.default, + aten.copy_.default, + aten.detach.default, + aten.expand.default, + aten.empty_like.default, + aten.new_empty.default, + aten.new_empty_strided.default, + aten.ones_like.default, + aten._reshape_alias.default, + aten.select.int, + aten.select_backward.default, + aten.squeeze.dim, + aten.slice.Tensor, + aten.slice_backward.default, + aten.stack.default, + aten.split.Tensor, + aten.permute.default, + aten.t.default, + aten.transpose.int, + aten._to_copy.default, + aten.unsqueeze.default, + aten.unbind.int, + aten._unsafe_view.default, + aten.view.default, + aten.where.self, + aten.zero_.default, + aten.zeros_like.default, + aten.fill_.Scalar, + aten.stack.default + ] # yapf: disable + + for op in zero_flop_aten: + flop_mapping[op] = zero_flop_jit + +else: + flop_mapping = {} + elementwise_flop_aten = {} + zero_flop_aten = {} diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..c87cd4321d31c59ebe369a2903b679a439fb4f96 --- /dev/null +++ b/colossalai/fx/profiler/profiler.py @@ -0,0 +1,409 @@ +import time +from functools import partial +from typing import Any, Callable, Dict, Tuple + +import torch +from torch.fx import Graph, Node +from torch.fx.node import Argument, Target +from torch.nn.parameter import Parameter +from torch.utils._pytree import tree_map + +from .._compatibility import compatibility +from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS +from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase +from .memory_utils import activation_size, parameter_size +from .opcount import flop_mapping +from .tensor import MetaTensor + +__all__ = ['profile_function', 'profile_module', 'profile_method'] + +# super-dainiu: this cache should be global, otherwise it cannot +# track duplicated tensors between nodes +cache = set() + +# a global identifier for inplace ops +do_not_cache = False + + +def normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +def is_autogradable(x): + return isinstance(x, torch.Tensor) and x.is_floating_point() + + +def detach_variables(x): + if isinstance(x, torch.Tensor): + requires_grad = x.requires_grad + x = x.detach() + x.requires_grad = requires_grad + + return x + + +@compatibility(is_backward_compatible=True) +def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]: + """Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30 + To profile the actual forward memory, we first run target in the context torch.no_grad() to get + the fwd_mem_out, then we run target with grad enable to found the extra memory stored in the memory + by memory allocated minus the fwd_mem_out. + To profile the actual backward memory, we first make dummy gradient for torch.autograd.backward, then + find the bwd_mem_tmp with memory peak during the process minus bwd_mem_out(it is actually equal to size + of args and kwargs). + We also add time stamps to profile the real forward and backward time. + + Args: + target (Callable): A Callable function + args (Any): Arguments + kwargs (Any): Arguments + + Returns: + Tuple[Tuple[Any, ...], GraphInfo]: Output for next node & memory cost and real forward and backward + time. + """ + + graphinfo = GraphInfo() + + # detach input from the graph + args = tree_map(detach_variables, args) + kwargs = tree_map(detach_variables, kwargs) + if isinstance(target, str): + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + + # calculate fwd_mem_out + mem_stamp0 = torch.cuda.memory_allocated() + with torch.no_grad(): + out = getattr(self_obj, target)(*args_tail, **kwargs) + mem_stamp1 = torch.cuda.memory_allocated() + graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0 + del out + + # calculate fwd_mem_tmp & fwd_time + mem_stamp0 = torch.cuda.memory_allocated() + fwd_time0 = time.time() + out = getattr(self_obj, target)(*args_tail, **kwargs) + fwd_time1 = time.time() + graphinfo.fwd_time = fwd_time1 - fwd_time0 + mem_stamp1 = torch.cuda.memory_allocated() + graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out + + # calculate bwd_mem_tmp & bwd_time + grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + bwd_time0 = time.time() + torch.autograd.backward(out, grad_tensors=grad_tensors) + bwd_time1 = time.time() + graphinfo.bwd_time = bwd_time1 - bwd_time0 + mem_stamp1 = torch.cuda.max_memory_allocated() + + # calculate bwd memory stats + # NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation + graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs) + graphinfo.bwd_mem_out += parameter_size(target.__self__) if hasattr(target.__self__, "parameters") else 0 + graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out + + else: + # calculate fwd_mem_out + mem_stamp0 = torch.cuda.memory_allocated() + with torch.no_grad(): + out = target(*args, **kwargs) + mem_stamp1 = torch.cuda.memory_allocated() + graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0 + del out + + # calculate fwd_mem_tmp & fwd_time + mem_stamp0 = torch.cuda.memory_allocated() + fwd_time0 = time.time() + out = target(*args, **kwargs) + fwd_time1 = time.time() + graphinfo.fwd_time = fwd_time1 - fwd_time0 + mem_stamp1 = torch.cuda.memory_allocated() + graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out + + # calculate bwd_mem_tmp & bwd_time + grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + bwd_time0 = time.time() + torch.autograd.backward(out, grad_tensors=grad_tensors) + bwd_time1 = time.time() + graphinfo.bwd_time = bwd_time1 - bwd_time0 + mem_stamp1 = torch.cuda.max_memory_allocated() + + # calculate bwd memory stats + # NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation + graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs) + graphinfo.bwd_mem_out += parameter_size(target.__self__) if hasattr(target.__self__, "parameters") else 0 + graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out + + return tree_map(detach_variables, out), graphinfo + + +@compatibility(is_backward_compatible=False) +def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]: + """ + Profile a Callable function with args and kwargs on meta devices. + + Args: + target (Callable): A Callable function + args (Any): Argument + kwargs (Any): Argument + + Returns: + out (Tuple[Any, ...]): The argument value that was retrieved. + meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + # This subgraph traces aten level ops inside one node. + subgraph = Graph() + + # `flop_count`` serves as a global dictionary to store results. + flop_count = { + Phase.FORWARD: 0, + Phase.BACKWARD: 0, + } + + # FlopTensor not only get the flop statistics of a single node, + # it also build a full autograd graph for this node. + # This makes sure we can analyze the dependencies of memory, and + # decide which forward intermediate results should be kept until + # backward is executed. + # Hopefully, this attempt will provide a better estimation of memory. + class FlopTensor(MetaTensor): + + _node: Node = None + + def __repr__(self): + if self.grad_fn: + return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, grad_fn={self.grad_fn})" + return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, requires_grad={self.requires_grad})" + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args) + kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs) + node = subgraph.create_node('call_function', func, args_node, kwargs_node) + + out = super().__torch_dispatch__(func, types, args, kwargs) + + flop_count[phase] += flop_mapping[func](args, normalize_tuple(out)) + node.meta['phase'] = phase + + # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs, + # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during + # `Phase.FORWARD` + if phase == Phase.FORWARD: + if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN: + node.meta['phase'] = Phase.PLACEHOLDER + + # TODO(yby): specify `saved_tensors` for backward memory estimation + node.meta['saved_tensor'] = [] + if phase == Phase.BACKWARD: + node.meta['saved_tensor'] = normalize_tuple(out) + + def wrap(x): + if isinstance(x, MetaTensor): + x = FlopTensor(x) + x._node = node + return x + + out = tree_map(wrap, out) + return out + + def wrap(x): + if isinstance(x, torch.Tensor): + x = FlopTensor(x) + if is_autogradable(x): + x.requires_grad_(True) + x._node = subgraph.create_node('placeholder', + 'placeholder', (subgraph._root,), + name=subgraph._graph_namespace.create_name('input', x._tensor)) + x._node.meta['phase'] = Phase.PLACEHOLDER + x._node.meta['saved_tensor'] = [] + return x + + # Basically, we need to detach the args and kwargs from the outer graph. + args = tree_map(wrap, args) + kwargs = tree_map(wrap, kwargs) + + def pack(x): + global cache, do_not_cache + if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache: + tensor = x._tensor.detach() + tensor.data_ptr = x._tensor.data_ptr + x._node.meta['saved_tensor'] += [tensor] + if not do_not_cache: + cache.add(x._tensor.data_ptr()) + return x + + def unpack(x): + return x + + # `phase` will mark the phase of autograd from outside scope. + phase = Phase.FORWARD + # mark saved tensors with saved_tensors_hooks + with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + if isinstance(target, str): + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + out = getattr(self_obj, target)(*args_tail, **kwargs) + else: + out = target(*args, **kwargs) + + # If the output is not a floating point `torch.Tensor` or it does not + # requires grad, then we should not run backward for this node. + if all(map(lambda x: is_autogradable(x) and x.requires_grad, normalize_tuple(out))): + grad_out = [torch.zeros_like(t) for t in normalize_tuple(out)] + phase = Phase.BACKWARD + torch.autograd.backward( + out, + grad_out, + ) + + graph_info = autograd_graph_analysis(subgraph) + graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD] + + def extract_tensor(x: Any): + if isinstance(x, MetaTensor): + tensor = x._tensor.detach() + tensor.data_ptr = x._tensor.data_ptr + return tensor + if not isinstance(x, torch.finfo): + return x + + graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out))) + + def unwrap(x): + return MetaTensor(x) if isinstance(x, torch.Tensor) else x + + return tree_map(unwrap, out), graph_info + + +@compatibility(is_backward_compatible=True) +def profile_function(target: 'Target', device: str = 'meta') -> Callable: + """ + Wrap a `call_function` node or `torch.nn.functional` in order to + record the memory cost and FLOPs of the execution. + + Warnings: + You may only use tensors with `device=meta` for this wrapped function. + Only original `torch.nn.functional` are available. + + Examples: + >>> input = torch.rand(100, 100, 100, 100, device='meta') + >>> func = torch.nn.functional.relu + >>> output, meta_info = profile_function(func)(input) + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + + # find the grad for parameter in args and kwargs + param_size = 0 + + def get_param_size(x): + nonlocal param_size + if isinstance(x, Parameter): + param_size += activation_size(x) + + tree_map(get_param_size, args) + tree_map(get_param_size, kwargs) + + # If there is an argument that this `call_function` is inplace, we should + # still run the profiling but discard some results regarding `target` + global do_not_cache + + inplace = kwargs.get('inplace', False) + if target in OUTPUT_SAVED_OPS: + do_not_cache = True + if inplace: + do_not_cache = True + kwargs['inplace'] = False + if device == 'meta': + out, meta = _profile_meta(func, *args, **kwargs) + else: + out, meta = _profile_concrete(func, *args, **kwargs) + if inplace: + kwargs['inplace'] = True + meta.bwd_mem_tmp = 0 + meta.bwd_mem_out = 0 + do_not_cache = False + + meta.bwd_mem_out -= param_size + return out, meta + + f.__name__ = target.__name__ + func = target + return f + + +@compatibility(is_backward_compatible=True) +def profile_method(target: 'Target', device: str = 'meta') -> Callable: + """ + Wrap a `call_method` node + record the memory cost and FLOPs of the execution. + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + # execute the method and return the result + assert isinstance(target, str), f'{target} instance is not str.' + if device == 'meta': + out, meta = _profile_meta(target, *args, **kwargs) + else: + out, meta = _profile_concrete(target, *args, **kwargs) + return out, meta + + return f + + +@compatibility(is_backward_compatible=True) +def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: + """ + Wrap a `call_module` node or `torch.nn` in order to + record the memory cost and FLOPs of the execution. + + Warnings: + You may only use tensors with `device=meta` for this wrapped function. + Only original `torch.nn` are available. + + Example: + >>> input = torch.rand(4, 3, 224, 224, device='meta') + >>> mod = torch.nn.Conv2d(3, 128, 3) + >>> output, meta_info = profile_module(mod)(input) + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + + # calculate parameter size + param_size = parameter_size(module) + + # If there is an argument that this `call_module` is inplace, we should + # still run the profiling but discard some results regarding `module`. + global do_not_cache + + inplace = getattr(module, 'inplace', False) + if type(module) in OUTPUT_SAVED_MOD: + do_not_cache = True + if inplace: + do_not_cache = True + module.inplace = False + if device == 'meta': + out, meta = _profile_meta(func, *args, **kwargs) + else: + out, meta = _profile_concrete(func, *args, **kwargs) + if inplace: + module.inplace = True + meta.bwd_mem_tmp = 0 + meta.bwd_mem_out = 0 + do_not_cache = False + + # grad for param will not be counted + meta.bwd_mem_out -= param_size + return out, meta + + f.__name__ = module.__class__.__name__ + func = module.forward + return f diff --git a/colossalai/fx/profiler/shard_utils.py b/colossalai/fx/profiler/shard_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..34feefb4336ab4a7924f7023b6f887ba8610b25c --- /dev/null +++ b/colossalai/fx/profiler/shard_utils.py @@ -0,0 +1,114 @@ +import torch +from torch.fx import Node + +from .._compatibility import compatibility, is_compatible_with_meta +from .memory_utils import activation_size + +if is_compatible_with_meta(): + from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS + +__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"] + + +@compatibility(is_backward_compatible=False) +def calculate_fwd_in(n: Node) -> int: + """A helper function to calculate `fwd_in` (with sharding spec) + + Args: + n (Node): a node from the graph + + Returns: + fwd_in (int): the result of `fwd_in` + """ + # TODO(super-dainiu): should divide the memory by sharding spec + return activation_size(n.meta["fwd_in"]) + + +@compatibility(is_backward_compatible=False) +def calculate_fwd_tmp(n: Node) -> int: + """A helper function to calculate `fwd_tmp` (with sharding spec) + Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy. + + Args: + n (Node): a node from the graph + + Returns: + fwd_tmp (int): the result of `fwd_tmp` + """ + + # TODO(super-dainiu): should divide the memory by sharding spec + def is_relu_like_node(n: Node) -> bool: + """Check if a node is a ReLU-like node. + ReLU-like nodes have the following properties: + - They are either `call_function` or `call_module` + - Their output tensors are directly saved for backward + - Their input tensors are not saved for backward + + An example is `torch.nn.functional.softmax` which has (forward + backward): + def forward(self, input_2): + _softmax_default = torch.ops.aten._softmax.default(input_2, None, None); input_2 = None + zeros_like_default = torch.ops.aten.zeros_like.default(_softmax_default, dtype = None, layout = None, device = None, pin_memory = None) + detach_default = torch.ops.aten.detach.default(_softmax_default); _softmax_default = None + _softmax_backward_data_default = torch.ops.aten._softmax_backward_data.default(zeros_like_default, detach_default, None, None); zeros_like_default = detach_default = None + detach_default_1 = torch.ops.aten.detach.default(_softmax_backward_data_default); _softmax_backward_data_default = None + detach_default_2 = torch.ops.aten.detach.default(detach_default_1); detach_default_1 = None + + Args: + n (Node): A node from the graph + + Returns: + bool: Whether the node is a ReLU-like node + """ + if n.op == 'call_function': + return n.target in OUTPUT_SAVED_OPS + elif n.op == 'call_module': + return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD + return False + + if not is_relu_like_node(n): + return activation_size(n.meta["fwd_tmp"]) + return 0 + + +@compatibility(is_backward_compatible=False) +def calculate_fwd_out(n: Node) -> int: + """A helper function to calculate `fwd_out` (with sharding spec) + + Args: + n (Node): a node from the graph + + Returns: + fwd_out (int): the result of `fwd_out` + """ + + # TODO(super-dainiu): should divide the memory by sharding spec + def intersect(a, b): + return {k: a[k] for k in a if k in b} + + fwd_in = dict() + for u in n.users: + fwd_in.update({x.data_ptr(): x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor)}) + fwd_out = {x.data_ptr(): x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor)} + return activation_size(intersect(fwd_in, fwd_out)) + + +def calculate_fwd_time(n: Node) -> float: + """A helper function to calculate `fwd_time` (with sharding spec) + Args: + n (Node): a node from the graph + Returns: + fwd_time (float): the result of `fwd_time` + """ + # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs + return n.meta["fwd_time"] + + +def calculate_bwd_time(n: Node) -> float: + """A helper function to calculate `bwd_time` (with sharding spec) + Args: + n (Node): a node from the graph + Returns: + bwd_time (float): the result of `bwd_time` + """ + # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs + return n.meta["bwd_time"] diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..2ee5e5c47750c4bbc714987ce4bcae9f29e6ac71 --- /dev/null +++ b/colossalai/fx/profiler/tensor.py @@ -0,0 +1,137 @@ +import uuid + +import torch +from torch.types import _bool, _device, _dtype +from torch.utils._pytree import tree_flatten, tree_map + +from .._compatibility import compatibility +from .constants import ALIAS_ATEN + +__all__ = ['MetaTensor'] + + +def set_data_ptr(x): + if isinstance(x, torch.Tensor): + if not x.data_ptr(): + data_ptr = uuid.uuid4() + x.data_ptr = lambda: data_ptr + + +@compatibility(is_backward_compatible=False) +class MetaTensor(torch.Tensor): + """ + A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops. + `fake_device` is the device that `MetaTensor` is supposed to run on. + """ + + _tensor: torch.Tensor + + @staticmethod + def __new__(cls, elem, fake_device=None): + # Avoid multiple wrapping + if isinstance(elem, MetaTensor): + fake_device = elem.device if fake_device is None else fake_device + elem = elem._tensor + + # The wrapping tensor (MetaTensor) shouldn't hold any + # memory for the class in question, but it should still + # advertise the same device as before + r = torch.Tensor._make_wrapper_subclass( + cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + dtype=elem.dtype, + layout=elem.layout, + device=fake_device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')), + requires_grad=elem.requires_grad) # deceive the frontend for aten selections + r._tensor = elem + # ...the real tensor is held as an element on the tensor. + if not r._tensor.is_meta: + r._tensor = r._tensor.to(torch.device('meta')) + # only tensor not on `meta` should be copied to `meta` + set_data_ptr(r._tensor) + return r + + def __repr__(self): + if self.grad_fn: + return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})" + return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + fake_device = None + + def unwrap(x): + nonlocal fake_device + if isinstance(x, MetaTensor): + fake_device = x.device + x = x._tensor + elif isinstance(x, torch.Tensor): + fake_device = x.device + x = x.to(torch.device('meta')) + return x + + args = tree_map(unwrap, args) + kwargs = tree_map(unwrap, kwargs) + + if 'device' in kwargs: + fake_device = kwargs['device'] + kwargs['device'] = torch.device('meta') + + # run aten for backend=CPU but actually on backend=Meta + out = func(*args, **kwargs) + + # here we keep the uuid of input because ALIAS_ATEN do not generate a physical copy + # of the input + if func in ALIAS_ATEN: + out.data_ptr = args[0].data_ptr + + # Now, we want to continue propagating this tensor, so we rewrap Tensors in + # our custom tensor subclass + def wrap(x): + if isinstance(x, torch.Tensor): + nonlocal fake_device + if not x.is_meta: + x = x.to(torch.device('meta')) + return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x + + return tree_map(wrap, out) + + def to(self, *args, **kwargs) -> torch.Tensor: + """An extension of `torch.Tensor.to()` to MetaTensor + + Returns: + result (MetaTensor): MetaTensor + + Usage: + >>> tensor = MetaTensor(torch.rand(10), fake_device='cuda:100') + >>> tensor.to(torch.uint8) + MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), fake_device='cuda:100') + >>> tensor.to(torch.device('cuda:42')) + MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='cuda:42') + >>> tensor.to('vulkan') + MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan') + """ + # this imitates c++ function in the way of @overload + fake_device = None + + def replace(x): + nonlocal fake_device + if isinstance(x, str) or isinstance(x, _device): + fake_device = x + return 'meta' + return x + + elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs)) + return MetaTensor(elem, fake_device=fake_device) + + def cpu(self, *args, **kwargs): + if self.device.type == 'cpu': + return self.to(*args, **kwargs) + return self.to(*args, device='cpu', **kwargs) + + def cuda(self, device=None, non_blocking=False): + if device is not None: + return self.to(device=device, non_blocking=non_blocking) + return self.to(device='cuda:0', non_blocking=non_blocking) diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..7317072c6298b66280810c48536eb22b7edca7f0 --- /dev/null +++ b/colossalai/fx/proxy.py @@ -0,0 +1,129 @@ +import operator +from typing import Any, List, Union + +import torch +from torch.fx.proxy import Attribute, Proxy + +from colossalai.fx.tracer.meta_patch import meta_patched_function + +__all__ = ['ColoProxy'] + + +class ColoProxy(Proxy): + """ + ColoProxy is a proxy class which uses meta tensor to handle data-dependent control flow. The original torch.fx proxy + cannot be used to infer the condition statement, with this proxy, torch.fx can still run even with if statements. + + Example:: + + proxy = tracer.create_proxy(...) + proxy.meta_data = torch.empty(4, 2, device='meta') + print(len(proxy)) # expect output 4 + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.node._meta_data = None + + @property + def meta_data(self): + return self.node._meta_data + + @meta_data.setter + def meta_data(self, data: Any): + self.node._meta_data = data + + @property + def has_meta_data(self): + return self._meta_data is not None + + def _assert_meta_data_is_tensor(self): + assert torch.is_tensor( + self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}' + + def _assert_has_meta_data(self): + assert self._meta_data is not None, f'Meta data is not set for {self.node.name}' + + def __len__(self): + self._assert_has_meta_data() + return len(self.meta_data) + + def __int__(self): + self._assert_has_meta_data() + return int(self.meta_data) + + def __float__(self): + self._assert_has_meta_data() + return float(self.meta_data) + + def __bool__(self): + self._assert_has_meta_data() + return self.meta_data + + def __getattr__(self, k): + + return ColoAttribute(self, k) + + def __contains__(self, key): + if self.node.op == "placeholder": + # this is used to handle like + # if x in kwargs + # we don't handle this case for now + return False + return super().__contains__(key) + + +def extract_meta(*args, **kwargs): + """ + This function is copied from _tracer_utils.py to avoid circular import issue. + """ + + def _convert(val): + if isinstance(val, ColoProxy): + return val.meta_data + elif isinstance(val, (list, tuple)): + return type(val)([_convert(ele) for ele in val]) + return val + + new_args = [_convert(val) for val in args] + new_kwargs = {k: _convert(v) for k, v in kwargs.items()} + return new_args, new_kwargs + + +class ColoAttribute(ColoProxy): + + def __init__(self, root, attr: str): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._node = None + + @property + def node(self): + if self._node is None: + proxy = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}) + if not isinstance(proxy, ColoProxy): + meta_args, meta_kwargs = extract_meta(*(self.root, self.attr)) + meta_out = getattr(*meta_args, **meta_kwargs) + proxy = ColoProxy(proxy.node) + proxy.meta_data = meta_out + self._node = proxy.node + + return self._node + + def __call__(self, *args, **kwargs): + proxy = self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs) + if not isinstance(proxy, ColoProxy): + meta_args, meta_kwargs = extract_meta(*((self.root,) + args), **kwargs) + method = getattr(meta_args[0].__class__, self.attr) + if meta_patched_function.has(method): + meta_target = meta_patched_function.get(method) + elif meta_patched_function.has(method.__name__): + meta_target = meta_patched_function.get(method.__name__) + else: + meta_target = method + meta_out = meta_target(*meta_args, **meta_kwargs) + proxy = ColoProxy(proxy.node) + proxy.meta_data = meta_out + return proxy diff --git a/colossalai/fx/tracer/__init__.py b/colossalai/fx/tracer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..590555ce38bf30fab0c575bb69f05af74c8d3802 --- /dev/null +++ b/colossalai/fx/tracer/__init__.py @@ -0,0 +1,5 @@ +from colossalai.fx.tracer.meta_patch.patched_function.python_ops import operator_getitem + +from ._meta_trace import meta_trace +from ._symbolic_trace import symbolic_trace +from .tracer import ColoTracer diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..1c5abb81d271144ab666049d3ae2868cd9568497 --- /dev/null +++ b/colossalai/fx/tracer/_meta_trace.py @@ -0,0 +1,133 @@ +import torch +from torch.fx import Graph, Node +from torch.utils._pytree import tree_map + + +def normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +def is_autogradable(x): + return isinstance(x, torch.Tensor) and x.is_floating_point() + + +def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph: + """Trace forward and backward graph with MetaTensor + + Args: + module (torch.nn.Module): The target module for tracing. + + Returns: + graph (torch.fx.Graph): The computation graph. + + Usage: + >>> import torchvision.models as tm + >>> model = tm.alexnet() + >>> graph = meta_trace(model, torch.rand(1000, 3, 224, 224)) + >>> graph.print_tabular() + """ + graph = Graph() + namespace = graph._graph_namespace + + class MetaProxy(torch.Tensor): + """ + A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops. + """ + + _tensor: torch.Tensor + _node: Node + + __slots__ = ['_tensor', '_node'] + + @staticmethod + def __new__(cls, tensor, fake_device=None, placeholder=False, name=None): + r = torch.Tensor._make_wrapper_subclass( + cls, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + device=fake_device if fake_device is not None else tensor.device, + requires_grad=tensor.requires_grad) # deceive the frontend for aten selections + r._tensor = tensor + if placeholder: + if name is None: + name = 'input' + r._node = graph.create_node('placeholder', + 'placeholder', (graph._root,), + name=namespace.create_name(name, tensor)) + # ...the real tensor is held as an element on the tensor. + if not r._tensor.is_meta: + r._tensor = r._tensor.to(torch.device('meta')) + return r + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + + def unwrap(x): + nonlocal fake_device + if isinstance(x, MetaProxy): + fake_device = x.device + x = x._tensor + # assert not isinstance(x, MetaProxy) + elif isinstance(x, torch.Tensor): + fake_device = x.device + x = x.to(torch.device('meta')) + return x + + def get_node(x): + if isinstance(x, torch.Tensor) and not hasattr(x, '_node'): + x = MetaProxy(x, placeholder=True, name='weight') + return x if not hasattr(x, '_node') else x._node + + args_node = tree_map(get_node, args) + kwargs_node = tree_map(get_node, kwargs) + node = graph.create_node('call_function', func, args_node, kwargs_node) + + if 'device' in kwargs: + fake_device = kwargs['device'] + kwargs['device'] = torch.device('meta') + + args = tree_map(unwrap, args) + kwargs = tree_map(unwrap, kwargs) + + # run aten for backend=CPU but actually on backend=Meta + out = func(*args, **kwargs) + + # Now, we want to continue propagating this tensor, so we rewrap Tensors in + # our custom tensor subclass + def wrap(x): + if isinstance(x, torch.Tensor): + nonlocal fake_device + if not x.is_meta: + x = x.to(torch.device('meta')) + return MetaProxy( + x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x + + def set_node(x): + x._node = node + + out = tree_map(wrap, out) + tree_map(set_node, out) + + return out + + def wrap(x): + return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x + + args = tree_map(wrap, args) + kwargs = tree_map(wrap, kwargs) + + out = module(*args, **kwargs) + + for tensor in normalize_tuple(out): + if is_autogradable(tensor) and tensor.requires_grad: + grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance( + tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta')) + torch.autograd.backward(tensor, + MetaProxy(grad, fake_device=tensor.device, placeholder=True), + retain_graph=True) + return graph diff --git a/colossalai/fx/tracer/_symbolic_trace.py b/colossalai/fx/tracer/_symbolic_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..5c04eeace0ad97d1f6b9d0f962f363ef2ecd449b --- /dev/null +++ b/colossalai/fx/tracer/_symbolic_trace.py @@ -0,0 +1,55 @@ +from typing import Any, Callable, Dict, Optional, Union + +import torch + +from colossalai.fx import ColoGraphModule +from colossalai.fx._compatibility import compatibility + +from .tracer import ColoTracer + + +@compatibility(is_backward_compatible=True) +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + meta_args: Optional[Dict[str, Any]] = None, + trace_act_ckpt=False, +) -> ColoGraphModule: + """ + Symbolic tracing API + + Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule`` + constructed by recording operations seen while tracing through ``root``. + + With ``meta_args``, we can trace the model that are untraceable subject to control flow. If specified using + ``meta_args`` only, the tracing can be done ahead of time. + + Note that ``meta_args`` are kwargs, which contains the key of the argument's names and the value of the + argument's values. + + Uses: + >>> model = ... + + # if this works + >>> gm = symbolic_trace(model, concrete_args=concrete_args) + + # else try this + >>> gm = symbolic_trace(model, concrete_args=concrete_args, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')}) + + Args: + root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted + into a Graph representation. + concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be used for tracing. + meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``. + Defaults to None. + + Returns: + ColoGraphModule: A ``ColoGraphModule`` created from the recorded operations from ``root``. + + Warnings: + This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team. + + """ + graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=meta_args) + name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + return ColoGraphModule(root, graph, name) diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e160497a74444fd34eea50943eb0430875fe1252 --- /dev/null +++ b/colossalai/fx/tracer/_tracer_utils.py @@ -0,0 +1,52 @@ +from typing import Any, List, Union + +import torch + +from ..proxy import ColoAttribute, ColoProxy +from .meta_patch import meta_patched_function, meta_patched_module + +__all__ = ['is_element_in_list', 'extract_meta'] + + +def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]): + if isinstance(elements, (tuple, list, set)): + for ele in elements: + if ele not in list_: + return False, ele + else: + if elements not in list_: + return False, elements + + return True, None + + +def extract_meta(*args, **kwargs): + + def _convert(val): + if isinstance(val, ColoProxy): + return val.meta_data + elif isinstance(val, (list, tuple)): + return type(val)([_convert(ele) for ele in val]) + + return val + + new_args = [_convert(val) for val in args] + new_kwargs = {k: _convert(v) for k, v in kwargs.items()} + return new_args, new_kwargs + + +def compute_meta_data_for_functions_proxy(target, args, kwargs): + args_metas, kwargs_metas = extract_meta(*args, **kwargs) + + # fetch patched function + if meta_patched_function.has(target): + meta_target = meta_patched_function.get(target) + elif meta_patched_function.has(target.__name__): + meta_target = meta_patched_function.get(target.__name__) + else: + meta_target = target + meta_out = meta_target(*args_metas, **kwargs_metas) + if isinstance(meta_out, torch.Tensor): + meta_out = meta_out.to(device="meta") + + return meta_out diff --git a/colossalai/fx/tracer/bias_addition_patch/__init__.py b/colossalai/fx/tracer/bias_addition_patch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e724d6a22fa84ecd954a59ebc6eb9b8daa00a035 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/__init__.py @@ -0,0 +1,2 @@ +from .patched_bias_addition_function import * +from .patched_bias_addition_module import * diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..071bde4a5293e391d618973f803360deb4cd1b4c --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py @@ -0,0 +1,4 @@ +from .addbmm import Addbmm +from .addmm import Addmm +from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict +from .linear import Linear diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py new file mode 100644 index 0000000000000000000000000000000000000000..859a19bf6241bbbf4061e0f7564975682527b8c2 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py @@ -0,0 +1,75 @@ +import operator + +import torch +import torch.nn.functional as F + +from ...registry import bias_addition_function, bias_addition_method +from .bias_addition_function import LinearBasedBiasFunc + + +@bias_addition_method.register(torch.Tensor.addbmm) +@bias_addition_function.register(torch.addbmm) +class Addbmm(LinearBasedBiasFunc): + + def extract_kwargs_from_origin_func(self): + kwargs = {} + if 'beta' in self.kwargs: + kwargs['beta'] = self.kwargs['beta'] + if 'alpha' in self.kwargs: + kwargs['alpha'] = self.kwargs['alpha'] + return kwargs + + def create_non_bias_func_proxy(self, input_proxy, other_proxy): + """ + This method is used to create the non_bias_func proxy, the node created by this proxy will + compute the main computation, such as convolution, with bias option banned. + """ + assert self.substitute_func == torch.bmm + node_kind = 'call_function' + node_target = self.substitute_func + + node_args = (input_proxy, other_proxy) + # torch.bmm does not have any kwargs + node_kwargs = {} + non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return non_bias_func_proxy + + def insert_sum_node(self, input_proxy, sum_dims=0): + ''' + This method is used to sum the input_proxy through the sum_dims. + ''' + node_kind = 'call_function' + node_target = torch.sum + node_args = (input_proxy, sum_dims) + node_kwargs = {} + sum_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return sum_proxy + + def generate(self): + # The formula for addbmm is output = beta * input + alpha * (torch.bmm(b1, b2)) + + # doing the non-bias computation(temp_0 = torch.bmm(b1, b2)) + non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], self.args[2]) + + # doing sum on the batch dimension(temp_1 = torch.sum(temp_0, 0)) + sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy) + kwargs = self.extract_kwargs_from_origin_func() + + if 'beta' in kwargs: + beta = kwargs['beta'] + # doing the multiplication with beta if it exists(temp_2 = beta * input) + beta_proxy = self.create_mul_node(self.args[0], beta) + else: + beta_proxy = self.args[0] + + if 'alpha' in kwargs: + alpha = kwargs['alpha'] + # doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1) + alpha_proxy = self.create_mul_node(alpha, sum_proxy) + else: + alpha_proxy = sum_proxy + + # doing the addition(temp_4 = temp_2 + temp_3) + bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy) + + return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7d8d07aac941028d7c682043b5af2bcdf2537a --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py @@ -0,0 +1,60 @@ +import operator + +import torch +import torch.nn.functional as F + +from ...registry import bias_addition_function, bias_addition_method +from .bias_addition_function import LinearBasedBiasFunc + + +@bias_addition_method.register(torch.Tensor.addmm) +@bias_addition_function.register(torch.addmm) +class Addmm(LinearBasedBiasFunc): + + def extract_kwargs_from_origin_func(self): + kwargs = {} + if 'beta' in self.kwargs: + kwargs['beta'] = self.kwargs['beta'] + if 'alpha' in self.kwargs: + kwargs['alpha'] = self.kwargs['alpha'] + return kwargs + + def transpose_other_operand_for_linear(self, other_proxy): + ''' + This method is used to transpose the other operand for linear function. + For example: + input = torch.rand(3, 4) + m1 = torch.rand(3, 5) + m2 = torch.rand(5, 4) + original_output = torch.addmm(input, m1, m2) + # To keep the computation graph consistent with the origin computation graph, we need to transpose the m2 + # before we call the linear function. + new_output = torch.linear(m1, m2.transpose(0, 1)) + input + ''' + node_kind = 'call_function' + node_target = torch.transpose + node_args = (other_proxy, 0, 1) + node_kwargs = {} + transpose_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return transpose_proxy + + def generate(self): + transpose_proxy = self.transpose_other_operand_for_linear(self.args[2]) + non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy) + kwargs = self.extract_kwargs_from_origin_func() + + if 'beta' in kwargs: + beta = kwargs['beta'] + beta_proxy = self.create_mul_node(self.args[0], beta) + else: + beta_proxy = self.args[0] + + if 'alpha' in kwargs: + alpha = kwargs['alpha'] + alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy) + else: + alpha_proxy = non_bias_linear_func_proxy + + bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy) + + return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py new file mode 100644 index 0000000000000000000000000000000000000000..8a3786332c08d9a3320ce4c7bee8221dd3d10abd --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py @@ -0,0 +1,115 @@ +import operator +from abc import ABC, abstractmethod + +import torch +import torch.nn.functional as F + + +class BiasAdditionFunc(ABC): + """ + This class is used to construct the restructure computation graph for + call_func node with bias addition inside. + """ + + def __init__(self, tracer, target, args, kwargs, substitute_func): + self.tracer = tracer + self.target = target + self.args = args + self.kwargs = kwargs + self.substitute_func = substitute_func + + @abstractmethod + def extract_kwargs_from_origin_func(self): + """ + This method is used to extract the kwargs for further graph transform. + + For example: + The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2) + The kwargs for addmm function is {beta=1, alpha=1, output=None}, then we need + to insert two more operator.mul nodes for the computation graph to compute the + final result. + """ + pass + + @abstractmethod + def generate(self): + """ + This method is used to construct the whole restructure computation graph for call_func node with bias + addition inside. + + A whole restructure computation graph will contain a weight node, a bias node, a non-bias addition computation node, + a bias reshape node if needed and a bias addition node. + + Use torch.addmm as an example: + The origin node is: + %addmm: call_func[target=torch.addmm](args = (%input_1, m1, m2), kwargs = {beta=1, alpha=1}) + Restructured graph is: + %transpose : [#users=1] = call_function[target=torch.transpose](args = (%m2, 0, 1), kwargs = {}) + %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%m1, %transpose), kwargs = {}) + %mul : [#users=1] = call_function[target=operator.mul](args = (%input_1, 3), kwargs = {}) + %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {}) + %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {}) + """ + pass + + def create_mul_node(self, input_proxy, coefficent): + """ + This method is used to create a coefficent node for the numerical correctness. + The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2) + Therefore, we need to use this method insert two more operator.mul nodes for + the computation graph to compute the final result. + """ + node_kind = 'call_function' + node_target = operator.mul + node_args = ( + input_proxy, + coefficent, + ) + node_kwargs = {} + mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return mul_proxy + + +class LinearBasedBiasFunc(BiasAdditionFunc): + """ + This class is used to construct the restructure computation graph for + call_func node based on F.linear. + """ + + def create_non_bias_func_proxy(self, input_proxy, other_proxy): + """ + This method is used to create the non_bias_func proxy, the node created by this proxy will + compute the main computation, such as convolution, with bias option banned. + """ + assert self.substitute_func == torch.nn.functional.linear + node_kind = 'call_function' + node_target = self.substitute_func + + node_args = (input_proxy, other_proxy) + # non-bias linear does not have any kwargs + node_kwargs = {} + non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return non_bias_func_proxy + + def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy): + """ + This method is used to create the bias_addition_proxy, the node created by this proxy will + compute the sum of non_bias_func result and bias with some reshape operation if needed. + """ + bias_add_node_kind = 'call_function' + bias_add_node_target = operator.add + bias_add_args = (non_bias_func_proxy, bias_proxy) + bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) + return bias_add_proxy + + +func_to_func_dict = { + torch.addmm: F.linear, + torch.addbmm: torch.bmm, + F.linear: F.linear, +} + +method_to_func_dict = { + torch.Tensor.addmm: F.linear, + torch.Tensor.addbmm: torch.bmm, +} diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..e11ec0a364f1e5ee1445c97ae0c9b054d02bcfa2 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py @@ -0,0 +1,25 @@ +import operator + +import torch +import torch.nn.functional as F + +from ...registry import bias_addition_function +from .bias_addition_function import LinearBasedBiasFunc + + +@bias_addition_function.register(F.linear) +class Linear(LinearBasedBiasFunc): + + def extract_kwargs_from_origin_func(self): + assert 'bias' in self.kwargs + kwargs = {} + if 'bias' in self.kwargs: + kwargs['bias'] = self.kwargs['bias'] + return kwargs + + def generate(self): + non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1]) + kwargs = self.extract_kwargs_from_origin_func() + bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias']) + + return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/__init__.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3823bb3e2a20e3963cc451459ec963a36eb1139 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/__init__.py @@ -0,0 +1,3 @@ +from .bias_addition_module import * +from .conv import * +from .linear import * diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py new file mode 100644 index 0000000000000000000000000000000000000000..85f1553e304c9c45b2b8f1373e76e7023d452d47 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py @@ -0,0 +1,111 @@ +import operator +from abc import ABC, abstractmethod + +import torch +import torch.nn.functional as F + + +class BiasAdditionModule(ABC): + """ + This class is used to construct the restructure computation graph for + call_module node with bias addition inside. + """ + + def __init__(self, tracer, target, args, kwargs, substitute_func): + self.tracer = tracer + self.target = target + self.args = args + self.kwargs = kwargs + self.substitute_func = substitute_func + self.weight_proxy = self._create_weight_proxy() + self.bias_proxy = self._create_bias_proxy() + + def _create_weight_proxy(self): + """ + Create weight proxy, the node created by this proxy contains module weight. + + Note: this function will be invoked during module initializing, + you should never call this function. + """ + weight_node_kind = 'get_attr' + weight_node_target = self.target + '.weight' + weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {}) + return weight_proxy + + def _create_bias_proxy(self): + """ + Create bias proxy, the node created by this proxy contains module bias. + + Note: this function will be invoked during module initializing, + you should never call this function. + """ + bias_node_kind = 'get_attr' + bias_node_target = self.target + '.bias' + bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {}) + return bias_proxy + + @abstractmethod + def extract_kwargs_from_mod(self): + """ + This method is used to extract the kwargs for non-bias computation. + + For example: + The kwargs for conv2d module is {} because the attributes like 'padding' or 'groups' are + considered during module initilizing. However, we need to consider those attributes as kwargs + in F.conv2d. + """ + pass + + def create_non_bias_func_proxy(self, input_proxy=None): + """ + This method is used to create the non_bias_func proxy, the node created by this proxy will + compute the main computation, such as convolution, with bias option banned. + """ + node_kind = 'call_function' + node_target = self.substitute_func + if input_proxy is None: + input_proxy = self.args[0] + node_args = (input_proxy, self.weight_proxy) + node_kwargs = self.extract_kwargs_from_mod() + non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return non_bias_func_proxy + + def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy): + """ + This method is used to create the bias_addition_proxy, the node created by this proxy will + compute the sum of non_bias_func result and bias with some reshape operation if needed. + """ + bias_add_node_kind = 'call_function' + bias_add_node_target = operator.add + bias_add_args = (non_bias_func_proxy, bias_proxy) + bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) + return bias_add_proxy + + @abstractmethod + def generate(self): + """ + This method is used to construct the whole restructure computation graph for call_module node with bias + addition inside. + + A whole restructure computation graph will contain a weight node, a bias node, a non-bias addition computation node, + a bias reshape node if needed and a bias addition node. + + Use Conv2d module as an example: + The origin node is: + %conv: call_module[target=conv](args = (%x,), kwargs = {}) + Restructured graph is: + %conv_weight : [#users=1] = get_attr[target=conv.weight] + %conv_bias : [#users=1] = get_attr[target=conv.bias] + %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {}) + %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + """ + pass + + +module_to_func_dict = { + torch.nn.Linear: F.linear, + torch.nn.Conv1d: F.conv1d, + torch.nn.Conv2d: F.conv2d, + torch.nn.Conv3d: F.conv3d, +} diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..4b6c82a74f57d213ba3d8b68863053ddd1aabf9d --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py @@ -0,0 +1,56 @@ +import torch +import torch.nn.functional as F +from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple + +from ...registry import bias_addition_module +from .bias_addition_module import BiasAdditionModule + + +@bias_addition_module.register(torch.nn.Conv1d) +@bias_addition_module.register(torch.nn.Conv2d) +@bias_addition_module.register(torch.nn.Conv3d) +class BiasAdditionConv(BiasAdditionModule): + + def extract_kwargs_from_mod(self): + root = self.tracer.root + conv_module = root.get_submodule(self.target) + kwarg_attributes = ['groups', 'dilation', 'stride'] + non_bias_kwargs = {} + for attr_name in kwarg_attributes: + if hasattr(conv_module, attr_name): + non_bias_kwargs[attr_name] = getattr(conv_module, attr_name) + if conv_module.padding_mode != "zeros": + #TODO: non zeros mode requires some extra processing for input + conv_type = type(conv_module) + if conv_type == "torch.nn.Conv1d": + padding_element = _single(0) + elif conv_type == "torch.nn.Conv2d": + padding_element = _pair(0) + elif conv_type == "torch.nn.Conv3d": + padding_element = _triple(0) + non_bias_kwargs['padding'] = padding_element + else: + non_bias_kwargs['padding'] = getattr(conv_module, 'padding') + + return non_bias_kwargs + + def create_bias_reshape_proxy(self, dimensions): + """ + This method is used to reshape the bias node in order to make bias and + output of non-bias convolution broadcastable. + """ + bias_shape = [1] * (dimensions - 1) + bias_shape[0] = -1 + bias_reshape_node_kind = 'call_method' + bias_reshape_node_target = 'view' + bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape)) + bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target, + bias_reshape_node_args, {}) + return bias_reshape_proxy + + def generate(self): + non_bias_conv_func_proxy = self.create_non_bias_func_proxy() + output_dims = non_bias_conv_func_proxy.meta_data.dim() + bias_reshape_proxy = self.create_bias_reshape_proxy(output_dims) + bias_addition_proxy = self.create_bias_addition_proxy(non_bias_conv_func_proxy, bias_reshape_proxy) + return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f7b6ddab401a637aa2b43b4dd8d2ce9193266e --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py @@ -0,0 +1,17 @@ +import torch +import torch.nn.functional as F + +from ...registry import bias_addition_module +from .bias_addition_module import BiasAdditionModule + + +@bias_addition_module.register(torch.nn.Linear) +class BiasAdditionLinear(BiasAdditionModule): + + def extract_kwargs_from_mod(self): + return {} + + def generate(self): + non_bias_linear_func_proxy = self.create_non_bias_func_proxy() + bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, self.bias_proxy) + return bias_addition_proxy diff --git a/colossalai/fx/tracer/experimental.py b/colossalai/fx/tracer/experimental.py new file mode 100644 index 0000000000000000000000000000000000000000..88b65b6188fa67be33f7a15b55e7fd5d32d7c4cc --- /dev/null +++ b/colossalai/fx/tracer/experimental.py @@ -0,0 +1,650 @@ +import enum +import functools +import inspect +import operator +from contextlib import contextmanager +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import torch +from torch.fx import Graph, Node, Proxy, Tracer +from torch.utils._pytree import tree_map + +from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta +from colossalai.fx.tracer._tracer_utils import extract_meta, is_element_in_list +from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict +from colossalai.fx.tracer.registry import ( + bias_addition_function, + bias_addition_method, + bias_addition_module, + meta_patched_function, + meta_patched_module, +) + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +Target = Union[Callable[..., Any], str] +Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types + List[Any], # actually Argument + Dict[str, Any], # actually Argument + slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing + 'Node',]] +_CScriptMethod = ['add', 'mul', 'sub', 'div'] +_TorchNewMethod = [ + "arange", "zeros", "zeros_like", "ones", "ones_like", "full", "full_like", "empty", "empty_like", "eye", "tensor", + "finfo" +] +_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"] + + +def _truncate_suffix(s: str): + import re + return re.sub(r'_\d+$', '', s) + + +def default_device(): + return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + + +@compatibility(is_backward_compatible=False) +class ColoProxy(Proxy): + + def __init__(self, *args, data=None, **kwargs): + super().__init__(*args, **kwargs) + self._meta_data = data + + @property + def meta_data(self): + return self._meta_data + + @meta_data.setter + def meta_data(self, args): + wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x + self._meta_data = tree_map(wrap_fn, args) + + @classmethod + def __torch_function__(cls, orig_method, types, args=(), kwargs=None): + proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs)) + unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p + kwargs = {} if kwargs is None else kwargs + if proxy.meta_data is None: + proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + return proxy + + @classmethod + def from_torch_proxy(cls, proxy: Proxy): + return cls(proxy.node, proxy.tracer) + + def __repr__(self): + return f"ColoProxy({self.node.name}, meta_data={self.meta_data})" + + def __len__(self): + return len(self.meta_data) + + def __int__(self): + return int(self.meta_data) + + def __index__(self): + try: + return int(self.meta_data) + except: + return torch.zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__() + + def __float__(self): + return float(self.meta_data) + + def __bool__(self): + return self.meta_data + + def __getattr__(self, k): + return ColoAttribute(self, k, getattr(self._meta_data, k, None)) + + def __setitem__(self, key, value): + proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {}) + proxy.meta_data = self._meta_data + return proxy + + def __contains__(self, key): + if self.node.op == "placeholder": + # this is used to handle like + # if x in kwargs + # we don't handle this case for now + return False + return super().__contains__(key) + + def __isinstancecheck__(self, type): + return isinstance(self.meta_data, type) + + @property + def shape(self): + return self.meta_data.shape + + @property + def ndim(self): + return self.meta_data.ndim + + @property + def device(self): + proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {}) + proxy.meta_data = self.meta_data.device + return proxy + + @property + def dtype(self): + proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {}) + proxy.meta_data = self.meta_data.dtype + return proxy + + def to(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs}) + + def cpu(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs}) + + def cuda(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs}) + + +@compatibility(is_backward_compatible=False) +class ColoAttribute(ColoProxy): + + def __init__(self, root, attr: str, data=None): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._meta_data = data + self._node: Optional[Node] = None + + @property + def node(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + + def __repr__(self): + return f"ColoAttribute({self.node.name}, attr={self.attr})" + + +@compatibility(is_backward_compatible=False) +class ColoTracer(Tracer): + + def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs): + super().__init__(*args, **kwargs) + self._disable_module_getattr = False + self.proxy_buffer_attributes = True + + # whether the tracer will record the usage of torch.utils.checkpoint + self.trace_act_ckpt = trace_act_ckpt + # whether the current tracing occurs within the activation checkpoint functions + self.inside_torch_checkpoint_func = False + self.act_ckpt_region_count = 0 + + def proxy(self, node: Node) -> 'ColoProxy': + return ColoProxy(node, self) + + def create_proxy(self, + kind: str, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Callable[[Node], 'Proxy'] = None): + + proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p + if kind == 'placeholder': + proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get( + _truncate_suffix(target), None) + elif kind == 'get_attr': + self._disable_module_getattr = True + try: + attr_itr = self.root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + proxy.meta_data = attr_itr + finally: + self._disable_module_getattr = False + elif kind == 'call_function': + proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + elif kind == 'call_method': + self._disable_module_getattr = True + try: + if target == '__call__': + proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) + else: + if target not in _TensorPropertyMethod: + proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), + **tree_map(unwrap_fn, kwargs)) + finally: + self._disable_module_getattr = False + elif kind == 'call_module': + mod = self.root.get_submodule(target) + self._disable_module_getattr = True + try: + proxy.meta_data = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + finally: + self._disable_module_getattr = False + return proxy + + def create_node(self, *args, **kwargs) -> Node: + node = super().create_node(*args, **kwargs) + + if self.inside_torch_checkpoint_func: + # annotate the activation checkpoint module + node.meta['activation_checkpoint'] = self.act_ckpt_region_count + return node + + def trace(self, + root: torch.nn.Module, + concrete_args: Optional[Dict[str, torch.Tensor]] = None, + meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph: + + if meta_args is None: + meta_args = {} + + if concrete_args is None: + concrete_args = {} + + # check concrete and meta args have valid names + sig = inspect.signature(root.forward) + sig_names = set(sig.parameters.keys()) + meta_arg_names = set(meta_args.keys()) + + # update concrete args with default values + non_meta_arg_names = sig_names - meta_arg_names + for k, v in sig.parameters.items(): + if k in non_meta_arg_names and \ + k not in concrete_args and \ + v.default is not inspect.Parameter.empty: + concrete_args[k] = v.default + + # get non concrete arg names + concrete_arg_names = set(concrete_args.keys()) + non_concrete_arg_names = sig_names - concrete_arg_names + + def _check_arg_name_valid(names): + success, element = is_element_in_list(names, sig_names) + if not success: + raise KeyError( + f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function") + + _check_arg_name_valid(meta_arg_names) + _check_arg_name_valid(concrete_arg_names) + + self.concrete_args = concrete_args + self.meta_args = meta_args + + with _TorchTensorOverride(self), self.trace_activation_checkpoint(enabled=self.trace_act_ckpt): + self.graph = super().trace(root, concrete_args=concrete_args) + self.graph.lint() + return self.graph + + @contextmanager + def trace_activation_checkpoint(self, enabled: bool): + if enabled: + orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction + + class PatchedCheckpointFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, run_function, preserve_rng_state, *args): + # signal that the current tracing occurs within activaton checkpoint part + self.inside_torch_checkpoint_func = True + out = run_function(*args) + self.inside_torch_checkpoint_func = False + self.act_ckpt_region_count += 1 + return out + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + raise NotImplementedError( + "We do not implement the backward pass as we only trace the forward pass.") + + # override the checkpoint function + torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction + yield + + if enabled: + # recover the checkpoint function upon exit + torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func + + def _post_check(self, non_concrete_arg_names: Set[str]): + # This is necessary because concrete args are added as input to the traced module since + # https://github.com/pytorch/pytorch/pull/55888. + for node in self.graph.nodes: + if node.op == "placeholder": + # Removing default values for inputs as the forward pass will fail with them. + if node.target in non_concrete_arg_names: + node.args = () + # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor]. + # It cannot infer on the attributes and methods the input should have, and fails. + node.type = torch.Tensor + # It is a concrete arg so it is not used and should be removed. + else: + if hasattr(torch.fx._symbolic_trace, "_assert_is_none"): + # Newer versions of torch.fx emit an assert statement + # for concrete arguments; delete those before we delete + # the concrete arg. + to_delete = [] + for user in node.users: + if user.target == torch.fx._symbolic_trace._assert_is_none: + to_delete.append(user) + for user in to_delete: + self.graph.erase_node(user) + + self.graph.erase_node(node) + + # TODO: solves GraphModule creation. + # Without this, return type annotation "Tuple" is causing code execution failure. + if node.op == "output": + node.type = None + self.graph.lint() + + def _module_getattr(self, attr, attr_val, parameter_proxy_cache): + if getattr(self, "_disable_module_getattr", False): + return attr_val + + def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache): + for n, p in collection_to_search: + if attr_val is p: + if n not in parameter_proxy_cache: + kwargs = {} + if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters: + kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else + lambda node: ColoProxy(self, node, n, attr_val)) + val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type] + parameter_proxy_cache[n] = val_proxy + return parameter_proxy_cache[n] + return None + + if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): + maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache) + if maybe_buffer_proxy is not None: + return maybe_buffer_proxy + + if isinstance(attr_val, torch.nn.Parameter): + maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), + parameter_proxy_cache) + if maybe_parameter_proxy is not None: + return maybe_parameter_proxy + + return attr_val + + +@compatibility(is_backward_compatible=True) +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + meta_args: Optional[Dict[str, Any]] = None, + trace_act_ckpt=False, +) -> ColoGraphModule: + if is_compatible_with_meta(): + if meta_args is not None: + root.to(default_device()) + wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x + graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, + concrete_args=concrete_args, + meta_args=tree_map(wrap_fn, meta_args)) + root.cpu() + else: + graph = Tracer().trace(root, concrete_args=concrete_args) + else: + from .tracer import ColoTracer as OrigColoTracer + graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, + concrete_args=concrete_args, + meta_args=meta_args) + name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + return ColoGraphModule(root, graph, name) + + +@compatibility(is_backward_compatible=False) +class _TorchTensorOverride(object): + + def __init__(self, tracer: Tracer): + self.overrides = {} + self.tracer = tracer + + def __enter__(self): + + def wrap_tensor_method(target): + + @functools.wraps(target) + def wrapper(*args, **kwargs): + is_proxy = any(isinstance(p, ColoProxy) for p in args) | any( + isinstance(p, ColoProxy) for p in kwargs.values()) + if is_proxy: + # if the arg is a proxy, then need to record this function called on this proxy + # e.g. torch.ones(size) where size is an input proxy + self.tracer._disable_module_getattr = True + try: + proxy = self.tracer.create_proxy('call_function', target, args, kwargs) + finally: + self.tracer._disable_module_getattr = False + return proxy + else: + return target(*args, **kwargs) + + return wrapper, target + + self.overrides = { + target: wrap_tensor_method(getattr(torch, target)) + for target in _TorchNewMethod + if callable(getattr(torch, target)) + } + for name, (wrapper, orig) in self.overrides.items(): + setattr(torch, name, wrapper) + + def __exit__(self, exc_type, exc_val, exc_tb): + for name, (wrapper, orig) in self.overrides.items(): + setattr(torch, name, orig) + + +def meta_prop_pass(gm: ColoGraphModule, + root: torch.nn.Module, + meta_args: Optional[Dict[str, Any]] = None, + concrete_args: Optional[Dict[str, torch.Tensor]] = None): + + if meta_args is None: + meta_args = {} + + if concrete_args is None: + concrete_args = {} + + # check concrete and meta args have valid names + sig = inspect.signature(root.forward) + sig_names = set(sig.parameters.keys()) + meta_arg_names = set(meta_args.keys()) + + # update concrete args with default values + non_meta_arg_names = sig_names - meta_arg_names + for k, v in sig.parameters.items(): + if k in non_meta_arg_names and \ + k not in concrete_args and \ + v.default is not inspect.Parameter.empty: + concrete_args[k] = v.default + + for node in gm.graph.nodes: + node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args, + node.kwargs) + + +def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs): + unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n + if kind == 'placeholder': + meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None) + elif kind == 'get_attr': + attr_itr = root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + meta_out = attr_itr + elif kind == 'call_function': + meta_out = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + elif kind == 'call_method': + if target == '__call__': + meta_out = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) + else: + if target not in _TensorPropertyMethod: + meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), + **tree_map(unwrap_fn, kwargs)) + elif kind == 'call_module': + mod = root.get_submodule(target) + meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + else: + meta_out = None + return meta_out + + +def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs): + if kind == "placeholder" and target in meta_args and meta_args[target].is_meta: + meta_out = meta_args[target] + return meta_out + + if target in [getattr(torch, torch_func) for torch_func in _TorchNewMethod]: + # NOTE: tensor constructors in PyTorch define the `device` argument as + # *kwargs-only*. That is why this works. If you add methods to + # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, + # this will break and you will likely see issues where we cannot infer + # the size of the output. + if "device" in kwargs: + kwargs["device"] = "meta" + + try: + unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n + args_metas = tree_map(unwrap_fn, args) + kwargs_metas = tree_map(unwrap_fn, kwargs) + + if kind == "call_function": + # fetch patched function + if meta_patched_function.has(target): + meta_target = meta_patched_function.get(target) + elif meta_patched_function.has(target.__name__): + # use name for some builtin op like @ (matmul) + meta_target = meta_patched_function.get(target.__name__) + else: + meta_target = target + + meta_out = meta_target(*args_metas, **kwargs_metas) + + if isinstance(meta_out, torch.Tensor): + meta_out = meta_out.to(device="meta") + elif kind == "call_method": + method = getattr(args_metas[0].__class__, target) + + # fetch patched method + if meta_patched_function.has(method): + meta_target = meta_patched_function.get(method) + else: + meta_target = method + + meta_out = meta_target(*args_metas, **kwargs_metas) + elif kind == "call_module": + mod = root.get_submodule(target) + mod_type = type(mod) + if meta_patched_module.has(mod_type): + meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas) + else: + meta_out = mod(*args_metas, **kwargs_metas) + elif kind == "get_attr": + attr_itr = root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + if isinstance(attr_itr, torch.nn.parameter.Parameter): + meta_out = torch.nn.Parameter(attr_itr.to(device="meta")) + elif isinstance(attr_itr, torch.Tensor): + meta_out = attr_itr.to(device="meta") + else: + meta_out = attr_itr + else: + return None + + except Exception as e: + raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}") + + return meta_out + + +def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None): + result_graph = Graph() + value_remap = {} + unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n + + for orig_node in gm.graph.nodes: + assert hasattr(orig_node, "_meta_data") + kind = orig_node.op + target = orig_node.target + args = orig_node.args + kwargs = orig_node.kwargs + + args_metas = tree_map(unwrap_fn, args) + tracer = ColoTracer() + tracer.graph = Graph(tracer_cls=ColoTracer) + tracer.root = root_model + + def wrap_fn(n): + if isinstance(n, Node): + proxy = ColoProxy(n, tracer) + proxy.meta_data = n._meta_data + return proxy + return n + + args_proxy = tree_map(wrap_fn, args) + kwargs_proxy = tree_map(wrap_fn, kwargs) + + handle = None + if kind == "call_function": + if bias_addition_function.has(target): + if target == torch.nn.functional.linear: + if 'bias' in kwargs and kwargs['bias'] is not None: + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, + function_to_substitute) + else: + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, + function_to_substitute) + elif bias_addition_function.has(target.__name__): + # use name for some builtin op like @ (matmul) + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, + function_to_substitute) + + elif kind == "call_method": + method = getattr(args_metas[0].__class__, target) + if bias_addition_method.has(method): + function_to_substitute = method_to_func_dict[method] + handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, + function_to_substitute) + + elif kind == "call_module": + # if not hasattr(self, "orig_forward"): + # raise AttributeError(f"{self} does not have an attribute called orig_forward") + mod = gm.get_submodule(target) + mod_type = type(mod) + if bias_addition_module.has(mod_type) and mod.bias is not None: + function_to_substitute = module_to_func_dict[mod_type] + handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, + function_to_substitute) + + if handle is not None: + handle.generate() + for node_inserted in tracer.graph.nodes: + value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n: value_remap[n]) + last_node = value_remap[node_inserted] + value_remap[orig_node] = last_node + else: + value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n: value_remap[n]) + + del tracer + + gm.graph = result_graph + gm.recompile() + meta_prop_pass(gm, root_model, meta_args) diff --git a/colossalai/fx/tracer/meta_patch/__init__.py b/colossalai/fx/tracer/meta_patch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..192aef7a4ba0388a817c8146b55c41311faa577a --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/__init__.py @@ -0,0 +1,2 @@ +from .patched_function import * +from .patched_module import * diff --git a/colossalai/fx/tracer/meta_patch/patched_function/__init__.py b/colossalai/fx/tracer/meta_patch/patched_function/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00fdf6f5c328e4e92a3911094325ada2182c5f9 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/__init__.py @@ -0,0 +1,6 @@ +from .activation_function import * +from .arithmetic import * +from .convolution import * +from .embedding import * +from .normalization import * +from .torch_ops import * diff --git a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py new file mode 100644 index 0000000000000000000000000000000000000000..12c42514895e61777f0dbb1c348206af7d0ac5ec --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py @@ -0,0 +1,8 @@ +import torch + +from ...registry import meta_patched_function + + +@meta_patched_function.register(torch.nn.functional.relu) +def torch_nn_func_relu(input, inplace=False): + return torch.empty(input.shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py new file mode 100644 index 0000000000000000000000000000000000000000..042b92c5847a4ed0d78d4acf068a0a5030fa7089 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py @@ -0,0 +1,95 @@ +import torch + +from ...registry import meta_patched_function + + +@meta_patched_function.register(torch.matmul) +@meta_patched_function.register('matmul') # for built-in op @ +def torch_matmul(input, other, *, out=None): + # copied from huggingface.utils.fx + d1 = input.dim() + d2 = other.dim() + shape = None + if d1 == 1 and d2 == 1: + shape = None + elif d1 == 2 and d2 == 2: + shape = (input.size(0), other.size(1)) + elif d1 == 1 and d2 == 2: + shape = (other.size(1),) + elif d1 == 2 and d2 == 1: + shape = (input.size(0),) + else: + max_length = max(input.dim(), other.dim()) + shape1 = list(input.shape) + shape2 = list(other.shape) + if d1 == 1: + shape1 = [1] + shape1 + if d2 == 1: + shape2.append(1) + shape1 = [-1] * (max_length - d1) + list(input.shape) + shape2 = [-1] * (max_length - d2) + list(other.shape) + shape = [] + for i in range(max_length): + shape.append(max(shape1[i], shape2[i])) + shape[-2] = shape1[-2] + shape[-1] = shape2[-1] + if d1 == 1: + shape.pop(-2) + if d2 == 1: + shape.pop(-1) + if shape is None: + return torch.tensor(0.0, device="meta") + return torch.empty(*shape, device="meta") + + +@meta_patched_function.register(torch.abs) +def torch_abs(input, *, out=None): + assert out is None, 'out is not supported yet' + return torch.empty(input.shape, device='meta') + + +@meta_patched_function.register(torch.bmm) +def torch_bmm(input, mat2, *, out=None): + if out is not None: + raise ValueError("Don't support in-place abs for MetaTensor analysis") + batch_size, n, m = input.shape + _, _, p = mat2.shape + return torch.empty(batch_size, n, p, device="meta") + + +@meta_patched_function.register(torch.nn.functional.linear) +def torch_linear(input, mat2, bias=None, *, out=None): + if out is not None: + raise ValueError("Don't support in-place abs for MetaTensor analysis") + output_shape = list(input.shape) + output_feature = list(mat2.shape)[0] + output_shape[-1] = output_feature + return torch.empty(*output_shape, device="meta") + + +@meta_patched_function.register(torch.addbmm) +@meta_patched_function.register(torch.Tensor.addbmm) +def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None): + if out is not None: + raise ValueError("Don't support in-place abs for MetaTensor analysis") + _, n, _ = mat1.shape + _, _, p = mat2.shape + return torch.empty(n, p, device="meta") + + +@meta_patched_function.register(torch.addmm) +@meta_patched_function.register(torch.Tensor.addmm) +def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None): + if out is not None: + raise ValueError("Don't support in-place abs for MetaTensor analysis") + n, _ = mat1.shape + _, p = mat2.shape + return torch.empty(n, p, device="meta") + + +@meta_patched_function.register(torch.var_mean) +def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None): + assert out is None, 'saving to out is not supported yet' + var = torch.empty(1).squeeze(0).to('meta') + mean = torch.empty(1).squeeze(0).to('meta') + return var, mean diff --git a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..8500e5c82508195ca3ec8ebc99a33ad2a1b946ad --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py @@ -0,0 +1,180 @@ +import collections +import math +from itertools import repeat + +import torch + +from ...registry import meta_patched_function + + +def _ntuple(n, name="parse"): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return tuple(x) + return tuple(repeat(x, n)) + + parse.__name__ = name + return parse + + +_single = _ntuple(1, "_single") +_pair = _ntuple(2, "_pair") +_triple = _ntuple(3, "_triple") + + +def _extract_kwargs(kwargs): + if 'stride' in kwargs: + stride = kwargs['stride'] + else: + stride = 1 + # TODO: process str type padding + if 'padding' in kwargs: + padding = kwargs['padding'] + else: + padding = 0 + if 'dilation' in kwargs: + dilation = kwargs['dilation'] + else: + dilation = 1 + if 'output_padding' in kwargs: + output_padding = kwargs['output_padding'] + else: + output_padding = 0 + + return stride, padding, dilation, output_padding + + +@meta_patched_function.register(torch.nn.functional.conv1d) +def torch_nn_functional_conv1d(input, weight, **kwargs): + stride, padding, dilation, _ = _extract_kwargs(kwargs) + + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + + kernel_size = weight.shape[2:] + l_in = input.shape[-1] + c_out = weight.shape[0] + l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) + result_shape = input.shape[:-2] + ( + c_out, + l_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_function.register(torch.nn.functional.conv2d) +def torch_nn_functional_conv2d(input, weight, **kwargs): + stride, padding, dilation, _ = _extract_kwargs(kwargs) + + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + + kernel_size = weight.shape[2:] + h_in, w_in = input.shape[-2:] + c_out = weight.shape[0] + h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) + w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1) + result_shape = input.shape[:-3] + ( + c_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_function.register(torch.nn.functional.conv3d) +def torch_nn_functional_conv3d(input, weight, **kwargs): + stride, padding, dilation, _ = _extract_kwargs(kwargs) + + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + + kernel_size = weight.shape[2:] + d_in, h_in, w_in = input.shape[-3:] + c_out = weight.shape[0] + d_out = math.floor((d_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) + h_out = math.floor((h_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1) + w_out = math.floor((w_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_function.register(torch.nn.functional.conv_transpose1d) +def torch_nn_functional_convtranspose1d(input, weight, **kwargs): + stride, padding, dilation, output_padding = _extract_kwargs(kwargs) + + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + output_padding = _single(output_padding) + + kernel_size = weight.shape[2:] + l_in = input.shape[-1] + c_out = weight.shape[1] + l_out = math.floor((l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + + output_padding[0] + 1) + result_shape = input.shape[:-2] + ( + c_out, + l_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_function.register(torch.nn.functional.conv_transpose2d) +def torch_nn_functional_convtranspose2d(input, weight, **kwargs): + stride, padding, dilation, output_padding = _extract_kwargs(kwargs) + + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + + kernel_size = weight.shape[2:] + h_in, w_in = input.shape[-2:] + c_out = weight.shape[1] + h_out = math.floor((h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + + output_padding[0] + 1) + w_out = math.floor((w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + + output_padding[1] + 1) + result_shape = input.shape[:-3] + ( + c_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_function.register(torch.nn.functional.conv_transpose3d) +def torch_nn_functional_convtranspose3d(input, weight, **kwargs): + stride, padding, dilation, output_padding = _extract_kwargs(kwargs) + + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + output_padding = _triple(output_padding) + + kernel_size = weight.shape[2:] + d_in, h_in, w_in = input.shape[-3:] + c_out = weight.shape[1] + d_out = math.floor((d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + + output_padding[0] + 1) + h_out = math.floor((h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + + output_padding[1] + 1) + w_out = math.floor((w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) + + output_padding[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..6d8d864ea29acd648f7f7097821c248655b0191e --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py @@ -0,0 +1,14 @@ +import torch + +from ...registry import meta_patched_function + + +@meta_patched_function.register(torch.nn.functional.embedding) +def torch_nn_functional_embedding(input, + weight, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False): + return torch.empty(*input.shape, weight.shape[-1], device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e7eda6159c88ca3a82320c841360df87540c82 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py @@ -0,0 +1,20 @@ +import torch + +from ...registry import meta_patched_function + + +@meta_patched_function.register(torch.nn.functional.layer_norm) +def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05): + return torch.empty(input.shape, device='meta') + + +@meta_patched_function.register(torch.nn.functional.batch_norm) +def torch_nn_func_batchnorm(input, + running_mean, + running_var, + weight=None, + bias=None, + training=False, + momentum=0.1, + eps=1e-05): + return torch.empty(input.shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4c171cb1099119de54b70b03097e1781880f2624 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py @@ -0,0 +1,60 @@ +import operator + +import torch + +from colossalai.fx.proxy import ColoProxy + +from ...registry import meta_patched_function + + +@meta_patched_function.register(operator.getitem) +def operator_getitem(a, b): + # copied from huggingface.utils.fx + def to_concrete(t): + if isinstance(t, torch.Tensor): + concrete = torch.ones_like(t, device="cpu") + if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]: + concrete = concrete.to(torch.int64) + return concrete + return t + + def _slice_convert(slice_obj): + attrs = {'start': slice_obj.start, 'stop': slice_obj.stop, 'step': slice_obj.step} + new_attrs = _slice_attr_convert(attrs) + attr_dict_to_tuple = (new_attrs['start'], new_attrs['stop'], new_attrs['step']) + return slice(*attr_dict_to_tuple) + + def _slice_attr_convert(attrs): + new_attrs = {} + for key, value in attrs.items(): + if isinstance(value, ColoProxy): + new_attrs[key] = value.meta_data + else: + new_attrs[key] = value + return new_attrs + + if isinstance(b, tuple): + b = list(b) + for index, element in enumerate(b): + if isinstance(element, slice): + b[index] = _slice_convert(element) + b = tuple(b) + elif isinstance(b, slice): + b = _slice_convert(b) + + if isinstance(a, torch.Tensor): + # TODO: infer shape without performing the computation. + if isinstance(b, tuple): + b = tuple(map(to_concrete, b)) + else: + b = to_concrete(b) + return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta") + + if isinstance(a, ColoProxy): + # TODO: infer shape without performing the computation. + if isinstance(b, tuple): + b = tuple(map(to_concrete, b)) + else: + b = to_concrete(b) + return operator.getitem(torch.empty_like(a.meta_data, device="cpu"), b).to("meta") + return operator.getitem(a, b) diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..b14ff10ce1377055ed7f3ab3025ee7c05c6a1657 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py @@ -0,0 +1,174 @@ +import torch + +from ...registry import meta_patched_function + + +@meta_patched_function.register(torch.arange) +def torch_arange(*args, **kwargs): + n = len(args) + step = 1 + if n == 1: + start = 0 + end = args[0] + elif n == 2: + start, end = args + else: + start, end, step = args + if isinstance(start, float): + start = int(start) + if isinstance(end, float): + start = int(end) + if isinstance(step, float): + step = int(step) + step = kwargs.get("step", step) + dtype = kwargs.get("dtype") + return torch.empty((end - start) // step, dtype=dtype, device="meta") + + +@meta_patched_function.register(torch.finfo) +def torch_finfo(*args): + return torch.finfo(*args) + + +@meta_patched_function.register(torch.where) +def torch_where(condition, x, y): + # torch.where returns the broadcasted tensor of condition, x, and y, + # so hack it by using addition + return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta") + + +@meta_patched_function.register(torch.Tensor.repeat) +def torch_tensor_repeat(self, *sizes): + shape = list(self.shape) + for i, x in enumerate(sizes): + shape[i] *= x + return torch.empty(shape, device="meta") + + +@meta_patched_function.register(torch.index_select) +def torch_index_select(input, dim, index, *, out=None): + shape = list(input.shape) + shape[dim] = len(index) + return torch.empty(*shape, device="meta") + + +@meta_patched_function.register(torch.Tensor.index_select) +def torch_tensor_index_select(self, dim, index): + return torch_index_select(self, dim, index) + + +@meta_patched_function.register(torch.squeeze) +def torch_squeeze(input, dim=None): + shape = list(input.shape) + if dim is not None: + if dim < 0: + dim = input.dim() + dim + if shape[dim] == 1: + shape.pop(dim) + else: + new_shape = [] + for dim_value in shape: + if dim_value == 1: + continue + new_shape.append(dim_value) + shape = new_shape + return torch.empty(shape, device="meta") + + +@meta_patched_function.register(torch.Tensor.squeeze) +def torch_tensor_squeeze(self, dim=None): + return torch_squeeze(self, dim) + + +@meta_patched_function.register(torch.unsqueeze) +def torch_unsqueeze(input, dim): + shape = list(input.shape) + if dim < 0: + dim = input.dim() + 1 + dim + shape.insert(dim, 1) + return torch.empty(shape, device="meta") + + +@meta_patched_function.register(torch.Tensor.unsqueeze) +def torch_tensor_unsqueeze(self, dim): + return torch_unsqueeze(self, dim) + + +@meta_patched_function.register(torch.cat) +def torch_cat(tensors, dim=None, axis=None, *, out=None): + if dim is None and axis is None: + dim = 0 + if dim is None and axis is not None: + dim = axis + if dim < 0: + dim = tensors[0].dim() + dim + shapes = [t.shape for t in tensors] + shape = list(shapes[0]) + concatenated_dim = sum(shape[dim] for shape in shapes) + final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:] + return torch.empty(final_shape, device="meta") + + +@meta_patched_function.register(torch.repeat_interleave) +def torch_repeat_interleave(input, repeats, dim=None, output_size=None): + assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \ + "Argument 'repeats' should be of type 'torch.Tensor' or 'int'" + + shape = list(input.shape) if dim is not None else [input.numel()] + dim = dim if dim is not None else 0 + dim = input.dim() + dim if dim < 0 else dim + + if isinstance(repeats, int): + shape[dim] = shape[dim] * repeats + elif isinstance(repeats, torch.Tensor): + shape[dim] = repeats.sum() + return torch.empty(shape, device="meta") + + +@meta_patched_function.register(torch.Tensor.repeat_interleave) +def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None): + return torch_repeat_interleave(self, repeats, dim, output_size) + + +@meta_patched_function.register(torch.roll) +def torch_roll(input, shifts, dims=None): + return torch.empty(input.shape, device='meta') + + +@meta_patched_function.register(torch.full) +def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + assert out is None, 'assigning result to out is not supported yet' + return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad) + + +@meta_patched_function.register(torch.max) +def torch_max(input, dim=None, keepdim=False, *, out=None): + assert out is None, 'assigning value to out is not supported yet' + if dim is not None: + if isinstance(dim, int): + shape = list(input.shape) + shape.pop(dim) + if keepdim: + shape.insert(dim, 1) + return torch.empty(shape, device='meta', dtype=input.dtype), torch.empty(shape, + device='meta', + dtype=input.dtype) + elif isinstance(dim, torch.Tensor): + # when dim is a 0D or 1D tensor, it will maintain the same shape + num_dims = dim.dim() + if num_dims in [0, 1]: + return torch.empty_like(input, device='meta') + else: + raise ValueError(f"Expected dim to a 0D or 1D tensor but got {num_dims} dimensions") + else: + return torch.empty([], device='meta', dtype=input.dtype) + + +@meta_patched_function.register(torch.Tensor.cpu) +def torch_tensor_cpu(input): + return input.clone() + + +@meta_patched_function.register(torch.Tensor.cuda) +def torch_tensor_cuda(input, *args, **kwargs): + return input.clone() diff --git a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e28e52585fffc193473a7c8270c103919cc63e0d --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py @@ -0,0 +1,7 @@ +from .activation_function import * +from .convolution import * +from .embedding import * +from .linear import * +from .normalization import * +from .pooling import * +from .rnn import * \ No newline at end of file diff --git a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py new file mode 100644 index 0000000000000000000000000000000000000000..d03da6588c1cbf56403dccc5989f4a4987b7e2a9 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py @@ -0,0 +1,13 @@ +import torch + +from ...registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.ReLU) +@meta_patched_module.register(torch.nn.Sigmoid) +@meta_patched_module.register(torch.nn.GELU) +@meta_patched_module.register(torch.nn.Tanh) +@meta_patched_module.register(torch.nn.ReLU6) +@meta_patched_module.register(torch.nn.PReLU) +def torch_nn_non_linear_act(self, input): + return torch.empty(input.shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..cf9f3487aac9f31ff799e3245132c40d643de64f --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py @@ -0,0 +1,113 @@ +import math + +import torch + +from ...registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.Conv1d) +def torch_nn_conv1d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d + l_in = input.shape[-1] + c_out = self.out_channels + l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + result_shape = input.shape[:-2] + ( + c_out, + l_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.Conv2d) +def torch_nn_conv2d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv2d + h_in, w_in = input.shape[-2:] + c_out = self.out_channels + h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * + (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + result_shape = input.shape[:-3] + ( + c_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.Conv3d) +def torch_nn_conv3d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv3d + d_in, h_in, w_in = input.shape[-3:] + c_out = self.out_channels + d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * + (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * + (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.ConvTranspose1d) +def torch_nn_convtranspose1d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html + l_in = input.shape[-1] + c_out = self.out_channels + l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * + (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + result_shape = input.shape[:-2] + ( + c_out, + l_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.ConvTranspose2d) +def torch_nn_convtranspose2d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + h_in, w_in = input.shape[-2:] + c_out = self.out_channels + h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * + (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * + (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + result_shape = input.shape[:-3] + ( + c_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.ConvTranspose3d) +def torch_nn_convtranspose3d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html + d_in, h_in, w_in = input.shape[-3:] + c_out = self.out_channels + d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * + (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * + (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] * + (self.kernel_size[2] - 1) + self.output_padding[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..999e33b17c1c7b442d2a6db73f957be4413f1fa1 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py @@ -0,0 +1,9 @@ +import torch + +from ...registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.Embedding) +def torch_nn_embedding(self, input): + result_shape = input.shape + (self.embedding_dim,) + return torch.empty(result_shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_module/linear.py b/colossalai/fx/tracer/meta_patch/patched_module/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..56f13bf97532e26770a0be7a4226ee69a2124ee5 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/linear.py @@ -0,0 +1,10 @@ +import torch + +from ...registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.Linear) +def torch_nn_linear(self, input): + last_dim = input.shape[-1] + assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch' + return torch.empty(input.shape[:-1] + (self.out_features,), device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..c21ff64cf3dec9baf357771fa0d15b341b413ac1 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py @@ -0,0 +1,31 @@ +import torch + +from ...registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.LayerNorm) +@meta_patched_module.register(torch.nn.GroupNorm) +@meta_patched_module.register(torch.nn.BatchNorm1d) +@meta_patched_module.register(torch.nn.BatchNorm2d) +@meta_patched_module.register(torch.nn.BatchNorm3d) +def torch_nn_normalize(self, input): + # check shape + if isinstance(self, torch.nn.BatchNorm1d): + assert input.dim() in [2, 3] + elif isinstance(self, torch.nn.BatchNorm2d): + assert input.dim() == 4 + elif isinstance(self, torch.nn.BatchNorm3d): + assert input.dim() == 5 + + # normalization maintain the same shape as the input + return input.clone() + + +try: + import apex + meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize) + meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) + meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) + meta_patched_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize) +except (ImportError, AttributeError): + pass diff --git a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..7ce23fbf7ac9368f4ec8496b252494e779fb5015 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py @@ -0,0 +1,202 @@ +import math + +import torch + +from ...registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.AvgPool1d) +def torch_nn_avgpool1d(self, input): + num_dim = input.dim() + assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions' + + l_in = input.shape[-1] + + def _convert_int_to_list(item): + if isinstance(item, int): + return [item] * 1 + else: + return item + + padding = _convert_int_to_list(self.padding) + kernel_size = _convert_int_to_list(self.kernel_size) + stride = _convert_int_to_list(self.stride) + + l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1) + + result_shape = tuple(input.shape[:-1]) + (l_out,) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.AvgPool2d) +def torch_nn_avgpool2d(self, input): + num_dim = input.dim() + assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions' + + h_in, w_in = input.shape[-2:] + + def _convert_int_to_list(item): + if isinstance(item, int): + return [item] * 2 + else: + return item + + padding = _convert_int_to_list(self.padding) + kernel_size = _convert_int_to_list(self.kernel_size) + stride = _convert_int_to_list(self.stride) + + h_out = math.floor((h_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1) + w_out = math.floor((w_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1) + + result_shape = tuple(input.shape[:-2]) + ( + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.AvgPool3d) +def torch_nn_avgpool3d(self, input): + num_dim = input.dim() + assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions' + + d_in, h_in, w_in = input.shape[-3:] + + def _convert_int_to_list(item): + if isinstance(item, int): + return [item] * 3 + else: + return item + + padding = _convert_int_to_list(self.padding) + kernel_size = _convert_int_to_list(self.kernel_size) + stride = _convert_int_to_list(self.stride) + + d_out = math.floor((d_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1) + h_out = math.floor((h_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1) + w_out = math.floor((w_in + 2 * padding[2] - kernel_size[2]) / stride[2] + 1) + + result_shape = tuple(input.shape[:-3]) + ( + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.MaxPool1d) +def torch_nn_maxpool1d(self, input): + num_dim = input.dim() + assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions' + + l_in = input.shape[-1] + + def _convert_int_to_list(item): + if isinstance(item, int): + return [item] * 1 + else: + return item + + padding = _convert_int_to_list(self.padding) + dilation = _convert_int_to_list(self.dilation) + kernel_size = _convert_int_to_list(self.kernel_size) + stride = _convert_int_to_list(self.stride) + + l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) + + result_shape = tuple(input.shape[:-1]) + (l_out,) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.MaxPool2d) +def torch_nn_maxpool2d(self, input): + num_dim = input.dim() + assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions' + + h_in, w_in = input.shape[-2:] + + def _convert_int_to_list(item): + if isinstance(item, int): + return [item] * 2 + else: + return item + + padding = _convert_int_to_list(self.padding) + dilation = _convert_int_to_list(self.dilation) + kernel_size = _convert_int_to_list(self.kernel_size) + stride = _convert_int_to_list(self.stride) + + h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) + w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1) + + result_shape = tuple(input.shape[:-2]) + ( + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.MaxPool3d) +def torch_nn_maxpool3d(self, input): + num_dim = input.dim() + assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions' + + d_in, h_in, w_in = input.shape[-3:] + + def _convert_int_to_list(item): + if isinstance(item, int): + return [item] * 3 + else: + return item + + padding = _convert_int_to_list(self.padding) + dilation = _convert_int_to_list(self.dilation) + kernel_size = _convert_int_to_list(self.kernel_size) + stride = _convert_int_to_list(self.stride) + + d_out = math.floor((d_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) + h_out = math.floor((h_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1) + w_out = math.floor((w_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1) + + result_shape = tuple(input.shape[:-3]) + ( + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.AdaptiveAvgPool1d) +@meta_patched_module.register(torch.nn.AdaptiveMaxPool1d) +def torch_nn_adapative_pooling_1d(self, input): + assert input.dim() in [2, 3] + if isinstance(self.output_size, int): + output_size = (self.output_size,) + else: + output_size = self.output_size + result_shape = tuple(input.shape[:-1]) + output_size + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.AdaptiveAvgPool2d) +@meta_patched_module.register(torch.nn.AdaptiveMaxPool2d) +def torch_nn_adapative_pooling_2d(self, input): + assert input.dim() in [3, 4] + if isinstance(self.output_size, int): + output_size = (self.output_size,) * 2 + else: + output_size = self.output_size + result_shape = tuple(input.shape[:-2]) + output_size + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.AdaptiveAvgPool3d) +@meta_patched_module.register(torch.nn.AdaptiveMaxPool3d) +def torch_nn_adapative_pooling_3d(self, input): + assert input.dim() in [4, 5] + if isinstance(self.output_size, int): + output_size = (self.output_size,) * 3 + else: + output_size = self.output_size + result_shape = tuple(input.shape[:-3]) + output_size + return torch.empty(result_shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..ee15ca34162e83612eb179e0cff066d9f06faf36 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py @@ -0,0 +1,16 @@ +from typing import Optional + +import torch + +from ...registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.GRU) +@meta_patched_module.register(torch.nn.RNN) +def torch_nn_rnn(self, input, hx): + assert input.shape[ + -1] == self.input_size, f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch' + assert hx.shape[ + -1] == self.hidden_size, f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch' + d = 2 if self.bidirectional else 1 + return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx diff --git a/colossalai/fx/tracer/registry.py b/colossalai/fx/tracer/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..12fc6de73d4435dea8ec58fa50b93a6070fd6254 --- /dev/null +++ b/colossalai/fx/tracer/registry.py @@ -0,0 +1,28 @@ +class PatchRegistry: + + def __init__(self, name): + self.name = name + self.store = {} + + def register(self, source): + + def wrapper(func): + self.store[source] = func + return func + + return wrapper + + def get(self, source): + assert source in self.store + target = self.store[source] + return target + + def has(self, source): + return source in self.store + + +meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution') +meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution') +bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition') +bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition') +bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition') diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..1ae31f9589756d4552f6c2247e0b71679625832b --- /dev/null +++ b/colossalai/fx/tracer/tracer.py @@ -0,0 +1,558 @@ +#!/usr/bin/env python +""" +tracer.py: + Implemented a tracer which supports control flow and user-defined meta arguments. + The implementation is partly inspired HuggingFace's fx tracer +""" +import enum +import functools +import inspect +import operator +from contextlib import contextmanager +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.fx import Node, Tracer +from torch.fx.graph import Graph, magic_methods, reflectable_magic_methods +from torch.fx.proxy import ParameterProxy, Proxy + +from ..proxy import ColoProxy +from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list +from .bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict +from .registry import ( + bias_addition_function, + bias_addition_method, + bias_addition_module, + meta_patched_function, + meta_patched_module, +) + +__all__ = ['ColoTracer'] + + +class TracerType(enum.Enum): + DEFAULT = 1 + META = 2 + + +class ColoTracer(Tracer): + """ + ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module. + This tracer is initialized in the same way as the original torch.fx.Tracer. + + Usage:: + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 10) + self.linear2 = nn.Linear(10, 10) + + def forward(self, x, y): + x1 = self.linear1(x) + y1 = self.linear2(y) + + if x1.dim() == 2: + return x1 + y1 + else: + return x1 - y1 + + model = Model() + tracer = ColoTracer() + graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')}) + """ + + def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.tracer_type = TracerType.META + self.proxy_cls = ColoProxy + + # whether the tracer will record the usage of torch.utils.checkpoint + self.trace_act_ckpt = trace_act_ckpt + # whether the current tracing occurs within the activation checkpoint functions + self.inside_torch_checkpoint_func = False + self.act_ckpt_region_count = 0 + + # Feature flag for proxying accesses to buffer values + proxy_buffer_attributes: bool = True + + _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor", "finfo"] + + def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy: + """ + Create a proxy for different kinds of operations. + """ + + if self.tracer_type == TracerType.DEFAULT: + # since meta_args is not given + # we just fall back to the original torch.fx.Tracer + proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + return proxy + + # if graph is traced for auto parallelism module, some extra node will be added during + # graph construction to deal with the compatability between bias addition and all reduce. + + # if no extra manipulation is applied, we just pass the origin arguments to create_proxy function + # to create node on computation graph + origin_arguments = (kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + # dispatch the arguments generator depending on the kind and target in origin arguments. + args_metas, _ = extract_meta(*args, **kwargs) + handle = None + if kind == "call_function": + if bias_addition_function.has(target): + if target == torch.nn.functional.linear: + if 'bias' in kwargs and kwargs['bias'] is not None: + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute) + else: + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute) + elif bias_addition_function.has(target.__name__): + # use name for some builtin op like @ (matmul) + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs, function_to_substitute) + + elif kind == "call_method": + method = getattr(args_metas[0].__class__, target) + if bias_addition_method.has(method): + function_to_substitute = method_to_func_dict[method] + handle = bias_addition_method.get(method)(self, target, args, kwargs, function_to_substitute) + + elif kind == "call_module": + if not hasattr(self, "orig_forward"): + raise AttributeError(f"{self} does not have an attribute called orig_forward") + self._disable_module_getattr = True + try: + mod = self.root.get_submodule(target) + mod_type = type(mod) + if bias_addition_module.has(mod_type) and mod.bias is not None: + function_to_substitute = module_to_func_dict[mod_type] + handle = bias_addition_module.get(mod_type)(self, target, args, kwargs, function_to_substitute) + finally: + self._disable_module_getattr = False + + if handle is not None: + return handle.generate() + + # create nodes using patched arguments + proxy = super().create_proxy(*origin_arguments) + proxy: ColoProxy + meta_out = self._meta_data_computing( + kind, + target, + args, + kwargs, + ) + proxy.meta_data = meta_out + + return proxy + + def _module_getattr(self, attr, attr_val, parameter_proxy_cache): + if getattr(self, "_disable_module_getattr", False): + return attr_val + else: + # return super()._module_getattr(attr, attr_val, parameter_proxy_cache) + def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache): + for n, p in collection_to_search: + if attr_val is p: + if n not in parameter_proxy_cache: + kwargs = {} + if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters: + kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else + lambda node: ParameterProxy(self, node, n, attr_val)) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] + parameter_proxy_cache[n] = val_proxy + return parameter_proxy_cache[n] + return None + + if isinstance(attr_val, torch.nn.Parameter): + maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), + parameter_proxy_cache) + if maybe_parameter_proxy is not None: + return maybe_parameter_proxy + + if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): + maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), + parameter_proxy_cache) + if maybe_buffer_proxy is not None: + return maybe_buffer_proxy + + return attr_val + + def call_module(self, m, forward, args, kwargs): + self.orig_forward = forward + module_qualified_name = self.path_of_module(m) + + # a leaf module is the torch.nn.Module subclasses starting with `torch.nn` + # which means customized modules are not leaf module by default + # if a customized or third-party module like apex.normalization.FusedRMSNorm is patched, + # we should treat it as leaf module as well + if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name): + return self.create_proxy('call_module', module_qualified_name, args, kwargs) + else: + return forward(*args, **kwargs) + + def proxy(self, node) -> Proxy: + """ + Returns a ColoProxy object. + """ + return self.proxy_cls(node, self) + + def _configure_tracer_type(self, tracer_type: TracerType): + if tracer_type == TracerType.DEFAULT: + self.proxy_cls = Proxy + self.tracer_type = TracerType.DEFAULT + elif tracer_type == TracerType.META: + self.proxy_cls = ColoProxy + self.tracer_type = TracerType.META + else: + raise ValueError(f"Unrecognised tracer type {tracer_type}") + + def _meta_data_computing(self, kind, target, args, kwargs): + + if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta: + meta_out = self.meta_args[target] + return meta_out + + if target in self.orig_torch_tensor_methods: + # NOTE: tensor constructors in PyTorch define the `device` argument as + # *kwargs-only*. That is why this works. If you add methods to + # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, + # this will break and you will likely see issues where we cannot infer + # the size of the output. + if "device" in kwargs: + kwargs["device"] = "meta" + + try: + args_metas, kwargs_metas = extract_meta(*args, **kwargs) + + if kind == "call_function": + # Our meta data will not record the nn.parameter.Parameter attribute。 + # It works fine in most of the case, but it may cause some problems after + # the bias addition manipulation. + # Therefore, I need to record the nn.parameter.Parameter attribute for the operation + # added by the bias addition manipulation following the get_attr node. + convert_to_parameter = False + if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0], + torch.nn.parameter.Parameter): + convert_to_parameter = True + # fetch patched function + if meta_patched_function.has(target): + meta_target = meta_patched_function.get(target) + elif meta_patched_function.has(target.__name__): + # use name for some builtin op like @ (matmul) + meta_target = meta_patched_function.get(target.__name__) + else: + meta_target = target + + meta_out = meta_target(*args_metas, **kwargs_metas) + if isinstance(meta_out, torch.Tensor): + meta_out = meta_out.to(device="meta") + if convert_to_parameter: + meta_out = torch.nn.Parameter(meta_out) + + elif kind == "call_method": + # Our meta data will not record the nn.parameter.Parameter attribute。 + # It works fine in most of the case, but it may cause some problems after + # the bias addition manipulation. + # Therefore, I need to record the nn.parameter.Parameter attribute for the operation + # added by the bias addition manipulation following the get_attr node. + convert_to_parameter = False + if target in (torch.Tensor.view,) and isinstance(args_metas[0], torch.nn.parameter.Parameter): + convert_to_parameter = True + method = getattr(args_metas[0].__class__, target) + + # fetch patched method + if meta_patched_function.has(method): + meta_target = meta_patched_function.get(method) + else: + meta_target = method + + meta_out = meta_target(*args_metas, **kwargs_metas) + if convert_to_parameter: + meta_out = torch.nn.Parameter(meta_out) + elif kind == "call_module": + if not hasattr(self, "orig_forward"): + raise AttributeError(f"{self} does not have an attribute called orig_forward") + self._disable_module_getattr = True + try: + mod = self.root.get_submodule(target) + mod_type = type(mod) + if meta_patched_module.has(mod_type): + meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas) + else: + meta_out = self.orig_forward(*args_metas, **kwargs_metas) + finally: + self._disable_module_getattr = False + elif kind == "get_attr": + self._disable_module_getattr = True + try: + attr_itr = self.root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + if isinstance(attr_itr, torch.nn.parameter.Parameter): + meta_out = torch.nn.Parameter(attr_itr.to(device="meta")) + elif isinstance(attr_itr, torch.Tensor): + meta_out = attr_itr.to(device="meta") + else: + meta_out = attr_itr + finally: + self._disable_module_getattr = False + else: + return None + + except Exception as e: + raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}") + + return meta_out + + def trace(self, + root: nn.Module, + concrete_args: Optional[Dict[str, Tensor]] = None, + meta_args: Optional[Dict[str, Tensor]] = None) -> Graph: + """ + Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow. + + Args: + root (nn.Module): a `nn.Module` object to trace the computation graph + meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph. + These arguments are the sample data fed to the model during actual computation, but just converted to meta tensors. + concrete_args (Optional[Dict[str, Tensor]]): the concrete arguments that should not be treated as Proxies. + """ + if meta_args is None: + meta_args = {} + + if concrete_args is None: + concrete_args = {} + + if len(meta_args) == 0: + self._configure_tracer_type(TracerType.DEFAULT) + else: + self._configure_tracer_type(TracerType.META) + + # check concrete and meta args have valid names + sig = inspect.signature(root.forward) + sig_names = set(sig.parameters.keys()) + meta_arg_names = set(meta_args.keys()) + + # update concrete args with default values + non_meta_arg_names = sig_names - meta_arg_names + for k, v in sig.parameters.items(): + if k in non_meta_arg_names and \ + k not in concrete_args and \ + v.default is not inspect.Parameter.empty: + concrete_args[k] = v.default + + # get non concrete arg names + concrete_arg_names = set(concrete_args.keys()) + non_concrete_arg_names = sig_names - concrete_arg_names + + def _check_arg_name_valid(names): + success, element = is_element_in_list(names, sig_names) + if not success: + raise KeyError( + f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function") + + _check_arg_name_valid(meta_arg_names) + _check_arg_name_valid(concrete_arg_names) + + # assign as attributed for late reference + def _check_kwargs(kwargs, should_be_meta: bool): + for k, v in kwargs.items(): + if not should_be_meta: + assert not torch.is_tensor(v) or not v.is_meta, \ + f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer' + else: + assert v.is_meta == should_be_meta, \ + f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer' + + _check_kwargs(concrete_args, should_be_meta=False) + _check_kwargs(meta_args, should_be_meta=True) + + self.concrete_args = concrete_args + self.meta_args = meta_args + + self.patched_torch_tensor_methods = {} + if self.tracer_type == TracerType.META: + # wrap the torch tensor constructing methods so that they are captured in the graph + self.patched_torch_tensor_methods = { + target: wrap_tensor_constructor_method(getattr(torch, target)) + for target in self._TORCH_METHODS_TO_PATCH + } + + # patch these methods to replace their original use + for name, (wrapper, orig) in self.patched_torch_tensor_methods.items(): + setattr(torch, name, wrapper) + + # cache these methods so that we can detect whether a method call + # should be patched during tracing + self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()] + + try: + # to track the usage of torch.utils.checkpoint + with self.trace_activation_checkpoint(enabled=self.trace_act_ckpt): + self.graph = super().trace(root, concrete_args=concrete_args) + + finally: + # recover the patched methods + for name, (_, orig) in self.patched_torch_tensor_methods.items(): + setattr(torch, name, orig) + + if self.tracer_type == TracerType.DEFAULT: + return self.graph + + # This is necessary because concrete args are added as input to the traced module since + # https://github.com/pytorch/pytorch/pull/55888. + for node in self.graph.nodes: + if node.op == "placeholder": + # Removing default values for inputs as the forward pass will fail with them. + if node.target in non_concrete_arg_names: + node.args = () + # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor]. + # It cannot infer on the attributes and methods the input should have, and fails. + node.type = torch.Tensor + # It is a concrete arg so it is not used and should be removed. + else: + if hasattr(torch.fx._symbolic_trace, "_assert_is_none"): + # Newer versions of torch.fx emit an assert statement + # for concrete arguments; delete those before we delete + # the concrete arg. + to_delete = [] + for user in node.users: + if user.target == torch.fx._symbolic_trace._assert_is_none: + to_delete.append(user) + for user in to_delete: + self.graph.erase_node(user) + + self.graph.erase_node(node) + + # TODO: solves GraphModule creation. + # Without this, return type annotation "Tuple" is causing code execution failure. + if node.op == "output": + node.type = None + + return self.graph + + @contextmanager + def trace_activation_checkpoint(self, enabled: bool): + if enabled: + orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction + + class PatchedCheckpointFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, run_function, preserve_rng_state, *args): + # signal that the current tracing occurs within activaton checkpoint part + self.inside_torch_checkpoint_func = True + out = run_function(*args) + self.inside_torch_checkpoint_func = False + self.act_ckpt_region_count += 1 + return out + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + raise NotImplementedError( + "We do not implement the backward pass as we only trace the forward pass.") + + # override the checkpoint function + torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction + yield + + if enabled: + # recover the checkpoint function upon exit + torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func + + def create_node(self, *args, **kwargs) -> Node: + node = super().create_node(*args, **kwargs) + + if self.inside_torch_checkpoint_func: + # annotate the activation checkpoint module + node.meta['activation_checkpoint'] = self.act_ckpt_region_count + return node + + +def wrap_tensor_constructor_method(target): + + def look_for_proxy(*args, **kwargs): + # find in pos vars + for arg in args: + if isinstance(arg, Proxy): + return arg + if isinstance(arg, (tuple, list)): + return look_for_proxy(*arg) + + # find in keyword vars + for k, v in kwargs.items(): + if isinstance(v, Proxy): + return v + if isinstance(v, (tuple, list)): + return look_for_proxy(*v) + return None + + @functools.wraps(target) + def wrapper(*args, **kwargs): + proxy = look_for_proxy(*args, **kwargs) + + if proxy is not None: + # if the arg is a proxy, then need to record this function called on this proxy + # e.g. torch.ones(size) where size is an input proxy + colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs) + if not isinstance(colo_proxy, ColoProxy): + meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs) + colo_proxy = ColoProxy(proxy.node) + colo_proxy.meta_data = meta_out + return colo_proxy + else: + # this is called directly when the inputs do not contain proxy + # e.g. torch.ones(4) where the input is static + return target(*args, **kwargs) + + return wrapper, target + + +# Patched magic methods for ColoProxy, then tracer could record the magic_method like __sub__, +# and add meta_data attribute to the created proxy. +for method in magic_methods: + + def _scope(method): + + def impl(*args, **kwargs): + + tracer = args[0].tracer + target = getattr(operator, method) + proxy = tracer.create_proxy('call_function', target, args, kwargs) + if not isinstance(proxy, ColoProxy): + meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs) + proxy = ColoProxy(proxy.node) + proxy.meta_data = meta_out + return proxy + + impl.__name__ = method + as_magic = f'__{method.strip("_")}__' + setattr(ColoProxy, as_magic, impl) + + _scope(method) + + +def _define_reflectable(orig_method_name): + method_name = f'__r{orig_method_name.strip("_")}__' + + def impl(self, rhs): + target = getattr(operator, orig_method_name) + proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {}) + if not isinstance(proxy, ColoProxy): + meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {}) + proxy = ColoProxy(proxy.node) + proxy.meta_data = meta_out + return proxy + + impl.__name__ = method_name + impl.__qualname__ = method_name + setattr(ColoProxy, method_name, impl) + + +for orig_method_name in reflectable_magic_methods: + _define_reflectable(orig_method_name) diff --git a/colossalai/global_variables.py b/colossalai/global_variables.py new file mode 100644 index 0000000000000000000000000000000000000000..61b31965e2e63d2119bfadba0d49478537c31fa7 --- /dev/null +++ b/colossalai/global_variables.py @@ -0,0 +1,56 @@ +from typing import Optional + + +class TensorParallelEnv(object): + _instance = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = object.__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self, *args, **kwargs): + self.load(*args, **kwargs) + + def load(self, + mode: Optional[str] = None, + vocab_parallel: bool = False, + parallel_input_1d: bool = False, + summa_dim: int = None, + tesseract_dim: int = None, + tesseract_dep: int = None, + depth_3d: int = None, + input_group_3d=None, + weight_group_3d=None, + output_group_3d=None, + input_x_weight_group_3d=None, + output_x_weight_group_3d=None): + self.mode = mode + self.vocab_parallel = vocab_parallel + self.parallel_input_1d = parallel_input_1d + self.summa_dim = summa_dim + self.tesseract_dim = tesseract_dim + self.tesseract_dep = tesseract_dep + self.depth_3d = depth_3d + self.input_group_3d = input_group_3d + self.weight_group_3d = weight_group_3d + self.output_group_3d = output_group_3d + self.input_x_weight_group_3d = input_x_weight_group_3d + self.output_x_weight_group_3d = output_x_weight_group_3d + + def save(self): + return dict(mode=self.mode, + vocab_parallel=self.vocab_parallel, + parallel_input_1d=self.parallel_input_1d, + summa_dim=self.summa_dim, + tesseract_dim=self.tesseract_dim, + tesseract_dep=self.tesseract_dep, + depth_3d=self.depth_3d, + input_group_3d=self.input_group_3d, + weight_group_3d=self.weight_group_3d, + output_group_3d=self.output_group_3d, + input_x_weight_group_3d=self.input_x_weight_group_3d, + output_x_weight_group_3d=self.output_x_weight_group_3d) + + +tensor_parallel_env = TensorParallelEnv() diff --git a/colossalai/initialize.py b/colossalai/initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..5d3f3e5530cbb71a6e88cd3726615c1b6b61a164 --- /dev/null +++ b/colossalai/initialize.py @@ -0,0 +1,470 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import argparse +import os +import pprint +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + +from colossalai.amp import AMP_TYPE, convert_to_amp +from colossalai.amp.naive_amp import NaiveAMPModel +from colossalai.builder.builder import build_gradient_handler +from colossalai.context import Config, ConfigException, ParallelMode +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.core import global_context as gpc +from colossalai.engine import Engine +from colossalai.engine.gradient_accumulation import accumulate_gradient +from colossalai.engine.schedule import ( + InterleavedPipelineSchedule, + NonPipelineSchedule, + PipelineSchedule, + get_tensor_shape, +) +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer +from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param +from colossalai.utils.moe import sync_moe_model_param +from colossalai.zero.legacy import ShardedOptimizerV2, convert_to_zero_v2 +from colossalai.zero.legacy.gemini.ophooks import BaseOpHook + + +def get_default_parser(): + """Reads user command line and uses an argument parser to parse the input arguments. + Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed. + + Returns: + Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser. + """ + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, help='path to the config file') + parser.add_argument('--host', type=str, help='the master address for distributed training') + parser.add_argument('--port', type=int, help='the master port for distributed training') + parser.add_argument('--world_size', type=int, help='world size for distributed training') + parser.add_argument('--rank', type=int, help='rank for the default process group') + parser.add_argument('--local_rank', type=int, help='local rank on the node') + parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication') + return parser + + +def launch(config: Union[str, Path, Config, Dict], + rank: int, + world_size: int, + host: str, + port: int, + backend: str = 'nccl', + local_rank: int = None, + seed: int = 1024, + verbose: bool = True): + """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input + arguments are not given. Then initialize and set distributed environment by calling global_context's functions. + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + rank (int): Rank for the default process group + world_size (int): World size of the default process group + host (str): The master address for distributed training + port (str): The master port for distributed training + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + local_rank (int, optional): + Rank for the process on the node and is used to set the default CUDA device, + defaults to None. If local_rank = None, the default device ordinal will be calculated automatically. + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + + Raises: + Exception: Raise exception when config type is wrong + """ + gpc.verbose = verbose + + # set config + assert isinstance(config, (Config, str, Path, dict)), \ + f'expected argument config to be Config, str or Path, but got {type(config)}' + if not isinstance(config, Config) and isinstance(config, dict): + config = Config(config) + if isinstance(config, (str, Path)): + config = Config.from_file(config) + gpc.load_config(config) + + # init default process group + gpc.init_global_dist(rank, world_size, backend, host, port) + + # init process groups for different parallel modes from config + gpc.init_parallel_groups() + + # set cuda device + if torch.cuda.is_available(): + # if local rank is not given, calculate automatically + gpc.set_device(local_rank) + + # set the number of processes running on the same node + gpc.detect_num_processes_on_current_node() + + gpc.set_seed(seed) + + if verbose: + logger = get_dist_logger() + logger.info( + f'Distributed environment is initialized, ' + f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, ' + f'tensor parallel size: {gpc.tensor_parallel_size}', + ranks=[0]) + + +def launch_from_slurm(config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = 'nccl', + seed: int = 1024, + verbose: bool = True): + """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables + set by SLURM + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + host (str): The master address for distributed training + port (str): The master port for distributed training + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + """ + try: + rank = int(os.environ['SLURM_PROCID']) + world_size = int(os.environ['SLURM_NPROCS']) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM" + ) + + launch(config=config, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose) + + +def launch_from_openmpi(config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = 'nccl', + seed: int = 1024, + verbose: bool = True): + """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables + set by OpenMPI + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + host (str): The master address for distributed training + port (str): The master port for distributed training + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + """ + try: + rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI" + ) + + launch(config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose) + + +def launch_from_torch(config: Union[str, Path, Config, Dict], + backend: str = 'nccl', + seed: int = 1024, + verbose: bool = True): + """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size + from the environment variables set by PyTorch + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + """ + try: + rank = int(os.environ['RANK']) + local_rank = int(os.environ['LOCAL_RANK']) + world_size = int(os.environ['WORLD_SIZE']) + host = os.environ['MASTER_ADDR'] + port = int(os.environ['MASTER_PORT']) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" + ) + + launch(config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose) + + +def initialize(model: nn.Module, + optimizer: Optimizer, + criterion: Optional[_Loss] = None, + train_dataloader: Optional[Iterable] = None, + test_dataloader: Optional[Iterable] = None, + lr_scheduler: Optional[_LRScheduler] = None, + ophooks: Optional[List[BaseOpHook]] = None, + verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]: + """Core function to wrap the essential training components with our functionality based on the config which is + loaded into gpc.config. + + Args: + model (:class:`torch.nn.Module` or Callbale): Your model instance or a function to build the model. + optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`): + Your optimizer instance. + criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance. + train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training. + test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing. + lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional. + verbose (bool, optional): Whether to print logs. + + Returns: + Tuple (engine, train_dataloader, test_dataloader, lr_scheduler): + A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)`` + where only ``engine`` could not be None. + """ + # get logger + logger = get_dist_logger() + gpc.verbose = verbose + + # get config from gpc + config = gpc.config + + # print config + if verbose: + logger.info( + f"\n========== Your Config ========\n" + f"{pprint.pformat(gpc.config)}\n" + f"================================\n", + ranks=[0]) + + # cudnn + cudnn_benchmark = config.get('cudnn_benchmark', False) + cudnn_deterministic = config.get('cudnn_deterministic', False) + torch.backends.cudnn.benchmark = cudnn_benchmark + torch.backends.cudnn.deterministic = cudnn_deterministic + if verbose: + logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0]) + + # zero + use_zero = hasattr(gpc.config, 'zero') + if use_zero: + zero_cfg = gpc.config.get('zero', None) + if zero_cfg is not None: + cfg_ = zero_cfg.copy() + else: + cfg_ = {} + optimizer_config = zero_cfg.get('optimizer_config', None) + model_config = zero_cfg.get('model_config', None) + model, optimizer = convert_to_zero_v2(model, + optimizer, + model_config=model_config, + optimizer_config=optimizer_config) + + logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0]) + else: + if isinstance(model, nn.Module): + # first sync model across dp ranks + model.to(get_current_device()) + elif isinstance(model, Callable): + model = model().to(get_current_device()) + + # optimizer maybe a optimizer_cls + if isinstance(optimizer, Callable): + optimizer = optimizer(model.parameters()) + logger.warning("Initializing an non ZeRO model with optimizer class") + + if not use_zero: + if is_using_sequence(): + sync_model_param(model, ParallelMode.SEQUENCE_DP) + elif MOE_CONTEXT.is_initialized: + sync_moe_model_param(model) + elif is_using_ddp(): + sync_model_param(model, ParallelMode.DATA) + else: + logger.warning( + "The parameters of models is not automatically synchronized.\n" + "Please make sure that all parameters are the same in data parallel group.", + ranks=[0]) + + # check amp and zero + fp16_cfg = gpc.config.get('fp16', None) + + if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero: + raise ConfigException( + "It is not allowed to set fp16 and zero configuration in your config file at the same time") + + # clip grad norm + clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0) + + # initialize amp + amp_mode = None + if fp16_cfg is not None and fp16_cfg.mode is not None: + cfg_ = fp16_cfg.copy() + amp_mode = cfg_.pop('mode') + if is_using_pp(): + assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently' + if amp_mode == AMP_TYPE.NAIVE: + cfg_['clip_grad_norm'] = clip_grad_norm + model, optimizer, criterion = convert_to_amp(model=model, + optimizer=optimizer, + criterion=criterion, + mode=amp_mode, + amp_config=cfg_) + + # get torch ddp config + torch_ddp_cfg = gpc.config.get('torch_ddp', dict()) + + # gradient handler + gradient_handler_cfg = gpc.config.get('gradient_handler', None) + if gradient_handler_cfg is None: + # if gradient handler is not specified in the configuration file, + # check in the following order + # 1. if optimizer is ZERO, then use zero grad handler + # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp + # 3. if using pipeline and dp size larger than 1, use data parallel grad handler + if isinstance(optimizer, ShardedOptimizerV2): + gradient_handler_cfg = [dict(type='ZeROGradientHandler')] + if verbose: + logger.info( + "Training with zero is detected, ZeROGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0]) + elif is_using_ddp() and MOE_CONTEXT.is_initialized: + gradient_handler_cfg = [dict(type='MoeGradientHandler')] + if verbose: + logger.info( + "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0]) + elif is_using_sequence(): + model = DDP(model, + process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), + device_ids=[torch.cuda.current_device()], + **torch_ddp_cfg) + if verbose: + logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', + ranks=[0]) + elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE: + model = DDP(model, + process_group=gpc.get_group(ParallelMode.DATA), + device_ids=[torch.cuda.current_device()], + **torch_ddp_cfg) + if verbose: + logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0]) + elif is_using_ddp(): + gradient_handler_cfg = [dict(type='DataParallelGradientHandler')] + if verbose: + logger.info( + "Data parallel training is detected when using pipeline parallel, " + "DataParallelGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0]) + # add pipeline parallel gradient handler, if pipeline shared module is detected + for param in model.parameters(): + if getattr(param, 'pipeline_shared_module_pg', None) is not None: + if gradient_handler_cfg is None: + gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')] + else: + gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler')) + if verbose: + logger.info( + "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0]) + break + else: + if not isinstance(gradient_handler_cfg, list): + raise ConfigException( + f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}" + ) + + # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time + # to avoid duplicated buffer synchronization + if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel): + model.module.sync_buffer = False + + # initialize schedule for engine + if is_using_pp(): + tensor_shape = get_tensor_shape() + use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks') + if gpc.is_initialized(ParallelMode.PARALLEL_1D): + scatter_gather = True + else: + scatter_gather = False + if use_interleaved: + if isinstance(model, nn.Sequential): + model = nn.ModuleList([model]) + schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, + gpc.config.model.num_chunks, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather) + else: + schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather) + else: + schedule = NonPipelineSchedule() + + if gradient_handler_cfg is None: + gradient_handlers = None + if verbose and not isinstance(model, DDP): + logger.warning( + "No PyTorch DDP or gradient handler is set up, please make sure you do not need " + "to all-reduce the gradients after a training step.", + ranks=[0]) + else: + gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg] + + # check if optimizer is ColossalaiOptimizer + if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)): + optimizer = ColossalaiOptimizer(optim=optimizer) + + # gradient accumulation + grad_accum_size = gpc.config.get('gradient_accumulation', None) + if grad_accum_size is not None: + optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient( + model=model, + optimizer=optimizer, + dataloader=train_dataloader, + accumulate_size=grad_accum_size, + gradient_handlers=gradient_handlers, + lr_scheduler=lr_scheduler) + engine = Engine(model=model, + optimizer=optimizer, + criterion=criterion, + gradient_handlers=gradient_handlers, + clip_grad_norm=clip_grad_norm, + ophook_list=ophooks, + schedule=schedule) + + return engine, train_dataloader, test_dataloader, lr_scheduler diff --git a/colossalai/interface/__init__.py b/colossalai/interface/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c658e375146ae4f6f31f62a9e913ed16fbb0714 --- /dev/null +++ b/colossalai/interface/__init__.py @@ -0,0 +1,4 @@ +from .model import ModelWrapper +from .optimizer import OptimizerWrapper + +__all__ = ['OptimizerWrapper', 'ModelWrapper'] diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a067d7671ce7eaaa174aced664ec16461cd03034 --- /dev/null +++ b/colossalai/interface/model.py @@ -0,0 +1,25 @@ +import torch.nn as nn + + +class ModelWrapper(nn.Module): + """ + A wrapper class to define the common interface used by booster. + + Args: + module (nn.Module): The model to be wrapped. + """ + + def __init__(self, module: nn.Module) -> None: + super().__init__() + self.module = module + + def unwrap(self): + """ + Unwrap the model to return the original model for checkpoint saving/loading. + """ + if isinstance(self.module, ModelWrapper): + return self.module.unwrap() + return self.module + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd9acab17584a452ad430094ab7fed4b9e272efb --- /dev/null +++ b/colossalai/interface/optimizer.py @@ -0,0 +1,121 @@ +from typing import Union + +import torch.nn as nn +from torch import Tensor +from torch.optim import Optimizer + + +class OptimizerWrapper: + """ + A standard interface for optimizers wrapped by the Booster. + + Args: + optim (Optimizer): The optimizer to be wrapped. + """ + + def __init__(self, optim: Optimizer): + self.optim = optim + + @property + def parameters(self): + params = [] + + for group in self.param_groups: + params += group['params'] + return params + + @property + def param_groups(self): + return self.optim.param_groups + + @property + def defaults(self): + return self.optim.defaults + + def add_param_group(self, *args, **kwargs): + return self.optim.add_param_group(*args, **kwargs) + + def step(self, *args, **kwargs): + """ + Performs a single optimization step. + """ + return self.optim.step(*args, **kwargs) + + def zero_grad(self, *args, **kwargs): + """ + Clears the gradients of all optimized `torch.Tensor`. + """ + self.optim.zero_grad(*args, **kwargs) + + def backward(self, loss: Tensor, *args, **kwargs): + """ + Performs a backward pass on the loss. + """ + loss.backward(*args, **kwargs) + + def state_dict(self): + """ + Returns the optimizer state. + """ + return self.optim.state_dict() + + def load_state_dict(self, *args, **kwargs): + """ + Loads the optimizer state. + """ + self.optim.load_state_dict(*args, **kwargs) + + def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: + """ + Clips gradient of an iterable of parameters at specified min and max values. + + Args: + clip_value (float or int): maximum allowed value of the gradients. Gradients are clipped in the range + + Note: + In PyTorch Torch 2.0 and above, you can pass in foreach=True as kwargs to clip_grad_value_ to use the + faster implementation. Please refer to the PyTorch documentation for more details. + """ + nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs) + + def clip_grad_by_norm(self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + error_if_nonfinite: bool = False, + *args, + **kwargs) -> Tensor: + """ + Clips gradient norm of an iterable of parameters. + + Args: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. + error_if_nonfinite (bool): if True, an error is raised if the total norm is non-finite. Default: False + + Note: + In PyTorch Torch 2.0 and above, you can pass in foreach=True as kwargs to clip_grad_norm_ to use the + faster implementation. Please refer to the PyTorch documentation for more details. + """ + norm = nn.utils.clip_grad_norm_(self.parameters, max_norm, norm_type, error_if_nonfinite, *args, **kwargs) + return norm + + def scale_loss(self, loss: Tensor): + """ + Scales the loss for mixed precision training. + + Note: Only available for optimizers with mixed precision training. + + Args: + loss (Tensor): The loss to be scaled. + """ + raise NotImplementedError( + "The method scale_loss is only available for optimizers with mixed precision training") + + def unscale_grad(self): + """ + Unscale the gradients for mixed precision training. + + Note: Only available for optimizers with mixed precision training. + """ + raise NotImplementedError( + "The method unscale_grad is only available for optimizers with mixed precision training") diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8933fc0a3c2f122a37656b6facfd904e039c8cdd --- /dev/null +++ b/colossalai/kernel/__init__.py @@ -0,0 +1,7 @@ +from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention + +__all__ = [ + "LayerNorm", + "FusedScaleMaskSoftmax", + "MultiHeadAttention", +] diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5a6ce495bdec0fe02475ed0bb3b3b67bb86b3c --- /dev/null +++ b/colossalai/kernel/cuda_native/__init__.py @@ -0,0 +1,5 @@ +from .layer_norm import MixedFusedLayerNorm as LayerNorm +from .multihead_attention import MultiHeadAttention +from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax + +__all__ = ['LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax'] diff --git a/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp b/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp new file mode 100644 index 0000000000000000000000000000000000000000..94f132521771bb18638d3a7edf03f7e4e14dcc27 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp @@ -0,0 +1,49 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu +#include + +void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + float scale); + +void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + float wd, float momentum, float dampening, float lr, + bool nesterov, bool first_run, + bool wd_after_momentum, float scale); + +void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, + const float beta2, const float epsilon, + const int step, const int mode, + const int bias_correction, const float weight_decay, + const float div_scale); + +void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, + const float beta2, const float epsilon, + const int step, const int bias_correction, + const float weight_decay, const int grad_averaging, + const int mode, at::Tensor global_grad_norm, + const float max_grad_norm, + at::optional use_nvlamb_python); + +std::tuple multi_tensor_l2norm_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("multi_tensor_scale", &multi_tensor_scale_cuda, + "Fused overflow check + scale for a list of contiguous tensors"); + m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda, + "Fused SGD optimizer for list of contiguous tensors"); + m.def("multi_tensor_adam", &multi_tensor_adam_cuda, + "Compute and apply gradient update to parameters for Adam optimizer"); + m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda, + "Computes and apply update for LAMB optimizer"); + m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda, + "Computes L2 norm for a list of contiguous tensors"); +} diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/colossalai/kernel/cuda_native/csrc/compat.h new file mode 100644 index 0000000000000000000000000000000000000000..00066dc95475296168c799904dc595ed435d2b0a --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/compat.h @@ -0,0 +1,10 @@ +// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0ab250218da38f9ded00766d6546d73b918699fa --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp @@ -0,0 +1,459 @@ +/* +Copyright (c) Microsoft Corporation. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE +*/ +#include "cpu_adam.h" + +#include +#include +#include + +#include +#include +#include +#include + +// C++ interface + +void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, + float *_exp_avg_sq, size_t _param_size, + bool param_half_precision, bool grad_half_precision, + float loss_scale) { + size_t rounded_size = 0; + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; + + __half *params_cast_h = NULL; + __half *grads_cast_h = NULL; + + if (param_half_precision) { + params_cast_h = reinterpret_cast<__half *>(_params); + } + if (grad_half_precision) { + grads_cast_h = reinterpret_cast<__half *>(grads); + } + +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + AVX_Data weight_decay_4; + if (_weight_decay > 0) + weight_decay_4.data = + (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH) { + AVX_Data grad_4; + if (grad_half_precision) { + grad_4.data = SIMD_LOAD_HALF(grads_cast_h + i); + } else { + grad_4.data = SIMD_LOAD(grads + i); + } + if (loss_scale > 0) { + AVX_Data loss_scale_vec; + loss_scale_vec.data = SIMD_SET(loss_scale); + grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data); + } + AVX_Data momentum_4; + momentum_4.data = SIMD_LOAD(_exp_avg + i); + + AVX_Data variance_4; + variance_4.data = SIMD_LOAD(_exp_avg_sq + i); + + AVX_Data param_4; + if (param_half_precision) { + param_4.data = SIMD_LOAD_HALF(params_cast_h + i); + } else { + param_4.data = SIMD_LOAD(_params + i); + } + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data); + } + momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data); + momentum_4.data = + SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data); + variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data); + grad_4.data = SIMD_MUL(grad_4.data, grad_4.data); + variance_4.data = + SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data); + grad_4.data = SIMD_SQRT(variance_4.data); + grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data); + grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data); + + if (_weight_decay > 0 && _adamw_mode) { + param_4.data = + SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data); + } + param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data); + + if (param_half_precision) { + SIMD_STORE_HALF((float *)(params_cast_h + i), param_4.data); + } else { + SIMD_STORE(_params + i, param_4.data); + } + SIMD_STORE(_exp_avg + i, momentum_4.data); + SIMD_STORE(_exp_avg_sq + i, variance_4.data); + } + } +#endif + if (_param_size > rounded_size) { + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k]; + if (loss_scale > 0) { + grad /= loss_scale; + } + float param = + param_half_precision ? (float)params_cast_h[k] : _params[k]; + float momentum = _exp_avg[k]; + float variance = _exp_avg_sq[k]; + if (_weight_decay > 0 && !_adamw_mode) { + grad = param * _weight_decay + grad; + } + momentum = momentum * _betta1; + momentum = grad * betta1_minus1 + momentum; + + variance = variance * _betta2; + grad = grad * grad; + variance = grad * betta2_minus1 + variance; + + grad = sqrt(variance); + grad = grad * _bias_correction2 + _eps; + grad = momentum / grad; + if (_weight_decay > 0 && _adamw_mode) { + param += w_decay * param; + } + param = grad * step_size + param; + + if (param_half_precision) + params_cast_h[k] = (__half)param; + else + _params[k] = param; + _exp_avg[k] = momentum; + _exp_avg_sq[k] = variance; + } + } + } +} + +void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, + float *_exp_avg_sq, size_t _param_size, + bool param_half_precision, bool grad_half_precision, + float loss_scale) { + size_t rounded_size = 0; + + __half *params_cast_h = NULL; + __half *grads_cast_h = NULL; + if (param_half_precision) { + params_cast_h = reinterpret_cast<__half *>(_params); + } + if (grad_half_precision) { + grads_cast_h = reinterpret_cast<__half *>(grads); + } + +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + float betta2_minus1 = 1 - _betta2; + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha / _bias_correction1; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay_4; + if (_weight_decay > 0) + weight_decay_4.data = + (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) { + AVX_Data grad_4[4]; + AVX_Data momentum_4[4]; + AVX_Data variance_4[4]; + AVX_Data param_4[4]; +#pragma unroll 4 + for (int j = 0; j < 4; j++) { + if (grad_half_precision) { + grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j); + } else { + grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j); + } + + if (loss_scale > 0) { + AVX_Data loss_scale_vec; + loss_scale_vec.data = SIMD_SET(loss_scale); + grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); + } + + momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); + variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); + + if (param_half_precision) { + param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j); + } else { + param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j); + } + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4[j].data = + SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data); + } + momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data); + momentum_4[j].data = + SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data); + variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data); + grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data); + variance_4[j].data = + SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data); + grad_4[j].data = SIMD_SQRT(variance_4[j].data); + grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data); + grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data); + + if (_weight_decay > 0 && _adamw_mode) { + param_4[j].data = + SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data); + } + param_4[j].data = + SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data); + if (param_half_precision) { + SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j), + param_4[j].data); + } else { + SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data); + } + SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data); + SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data); + } + } + } +#endif + if (_param_size > rounded_size) + Step_1((param_half_precision ? (float *)(params_cast_h + rounded_size) + : _params + rounded_size), + (grad_half_precision ? (float *)(grads_cast_h + rounded_size) + : grads + rounded_size), + (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), param_half_precision, + grad_half_precision, loss_scale); +} + +void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, + float *_exp_avg_sq, size_t _param_size, + bool param_half_precision, bool grad_half_precision, + float loss_scale) { + size_t rounded_size = 0; + __half *params_cast_h = NULL; + __half *grads_cast_h = NULL; + if (param_half_precision) { + params_cast_h = reinterpret_cast<__half *>(_params); + } + if (grad_half_precision) { + grads_cast_h = reinterpret_cast<__half *>(grads); + } +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + float betta2_minus1 = 1 - _betta2; + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha / _bias_correction1; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay_4; + if (_weight_decay > 0) + weight_decay_4.data = + (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) { + AVX_Data grad_4[8]; + AVX_Data momentum_4[8]; + AVX_Data variance_4[8]; + AVX_Data param_4[8]; +#pragma unroll 8 + for (int j = 0; j < 8; j++) { + if (grad_half_precision) { + grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j); + } else { + grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j); + } + + if (loss_scale > 0) { + AVX_Data loss_scale_vec; + loss_scale_vec.data = SIMD_SET(loss_scale); + grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); + } + + momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); + variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); + + if (param_half_precision) { + param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j); + } else { + param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j); + } + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4[j].data = + SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data); + } + momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data); + momentum_4[j].data = + SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data); + variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data); + grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data); + variance_4[j].data = + SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data); + grad_4[j].data = SIMD_SQRT(variance_4[j].data); + grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data); + grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data); + if (_weight_decay > 0 && _adamw_mode) { + param_4[j].data = + SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data); + } + param_4[j].data = + SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data); + + if (param_half_precision) { + SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j), + param_4[j].data); + } else { + SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data); + } + + SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data); + SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data); + } + } + } +#endif + if (_param_size > rounded_size) + Step_4((param_half_precision ? (float *)(params_cast_h + rounded_size) + : _params + rounded_size), + (grad_half_precision ? (float *)(grads_cast_h + rounded_size) + : grads + rounded_size), + (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), param_half_precision, + grad_half_precision, loss_scale); +} + +void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2, + float epsilon, float weight_decay, + bool bias_correction, torch::Tensor ¶ms, + torch::Tensor &grads, torch::Tensor &exp_avg, + torch::Tensor &exp_avg_sq, float loss_scale) { + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + float *params_ptr = (float *)params_c.data_ptr(); + float *grads_ptr = (float *)grads_c.data_ptr(); + float *exp_avg_ptr = (float *)exp_avg_c.data_ptr(); + float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr(); + + this->IncrementStep(step, beta1, beta2); + this->update_state(lr, epsilon, weight_decay, bias_correction); + this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, + params_c.numel(), (params.options().dtype() == at::kHalf), + (grads.options().dtype() == at::kHalf), loss_scale); +} + +namespace py = pybind11; + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + py::class_(m, "CPUAdamOptimizer") + .def(py::init()) + .def("step", &Adam_Optimizer::step); +} diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h new file mode 100644 index 0000000000000000000000000000000000000000..4247da94277518e438cbaf638de8eeaffbf26da5 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h @@ -0,0 +1,164 @@ +/* +Copyright (c) Microsoft Corporation. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE +*/ +#pragma once + +#include +#include +#include +#include +#include +#include +#if (__x86_64__ || __i386__) +#include +#include +#endif + +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define TILE (128 * 1024 * 1024) + +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + +#if defined(__AVX512__) +#define SIMD_WIDTH 16 +#define INTV __m256i +#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm512_loadu_ps(x) +#define SIMD_SET(x) _mm512_set1_ps(x) +#define SIMD_ADD(x, y) _mm512_add_ps(x, y) +#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm512_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm512_div_ps(x, y) +#define SIMD_LOAD_HALF(x) \ + _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) +#define SIMD_STORE_HALF(x, d) \ + _mm256_store_ps( \ + x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) + +#elif defined(__AVX256__) or defined(__AVX2__) +#define SIMD_WIDTH 8 +#define INTV __m128i +#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm256_loadu_ps(x) +#define SIMD_SET(x) _mm256_set1_ps(x) +#define SIMD_ADD(x, y) _mm256_add_ps(x, y) +#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm256_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm256_div_ps(x, y) +#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) +#define SIMD_STORE_HALF(x, d) \ + _mm_store_ps( \ + x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) + +#endif + +union AVX_Data { +#if defined(__AVX512__) + __m512 data; +#elif defined(__AVX256__) or defined(__AVX2__) + __m256 data; +#endif + // float data_f[16]; +}; + +#endif + +#define STEP(SPAN) \ + void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \ + float *_exp_avg_sq, size_t _param_size, \ + bool param_half_precision = false, \ + bool grad_half_precision = false, float loss_scale = -1); + +class Adam_Optimizer { + public: + Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999, + float eps = 1e-8, float weight_decay = 0, + bool adamw_mode = true) + : _alpha(alpha), + _betta1(betta1), + _betta2(betta2), + _eps(eps), + _weight_decay(weight_decay), + _betta1_t(1.0), + _betta2_t(1.0), + _step(0), + _adamw_mode(adamw_mode) {} + ~Adam_Optimizer() {} + + STEP(1) + STEP(4) + STEP(8) + inline void IncrementStep(size_t step, float beta1, float beta2) { + if (beta1 != _betta1 || beta2 != _betta2) { + _step = step; + _betta1 = beta1; + _betta2 = beta2; + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + } else { + _step++; + if (_step != step) { + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + _step = step; + } else { + _betta1_t *= _betta1; + _betta2_t *= _betta2; + } + } + } + inline void update_state(float lr, float epsilon, float weight_decay, + bool bias_correction) { + _alpha = lr; + _eps = epsilon; + _weight_decay = weight_decay; + + _bias_correction1 = 1.0f; + _bias_correction2 = 1.0f; + if (bias_correction == 1) { + _bias_correction1 = 1 - _betta1_t; + _bias_correction2 = 1 / sqrt(1 - _betta2_t); + } + } + + void step(size_t step, float lr, float beta1, float beta2, float epsilon, + float weight_decay, bool bias_correction, torch::Tensor ¶ms, + torch::Tensor &grads, torch::Tensor &exp_avg, + torch::Tensor &exp_avg_sq, float loss_scale); + + private: + float _alpha; + float _betta1; + float _betta2; + float _eps; + float _weight_decay; + + float _betta1_t; + float _betta2_t; + size_t _step; + + float _bias_correction1; + float _bias_correction2; + + bool _adamw_mode; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu new file mode 100644 index 0000000000000000000000000000000000000000..58d26235a9cc6954e9822119f215b9745b0a1684 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu @@ -0,0 +1,191 @@ +#include "block_reduce.h" +#include "cuda_util.h" +#include "kernels.h" +#include "ls_cub.cuh" + +ls::cub::CachingDeviceAllocator g_allocator(true); + +template +__global__ void ls_cross_entropy_fw_kernel( + const T *__restrict__ inputs, const int *__restrict__ targets, + float *__restrict__ outputs, float *__restrict__ nll_loss_outputs, + const int padding_idx, const float epsilon, const int vocab_size) { + /* step1: compute each thread's max_logit and sum_exp_logit, store in + * max_input, sum_exp_logit */ + const int block_start = blockIdx.x * vocab_size; + const int left_idx = block_start + threadIdx.x; + const int right_idx = (blockIdx.x + 1) * vocab_size; + float max_input[1] = {REDUCE_FLOAT_INF_NEG}; + float sum_logits[2] = {0.f, 0.f}; // logit and logit exp + int target_tid = targets[blockIdx.x]; + + if (target_tid == padding_idx) { + if (threadIdx.x == 0) { + nll_loss_outputs[blockIdx.x] = 0.f; + outputs[blockIdx.x] = 0.f; + } + return; + } + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); + } + blockReduce(max_input); + __shared__ float s_max_input; + if (threadIdx.x == 0) { + s_max_input = max_input[0]; + } + __syncthreads(); + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + float logit = static_cast(inputs[i]) - s_max_input; + sum_logits[0] += logit; + sum_logits[1] += expf(logit); + } + + blockReduce(sum_logits); + __shared__ float s_sum_logit; + __shared__ float s_sum_exp; + if (threadIdx.x == 0) { + s_sum_logit = sum_logits[0]; + s_sum_exp = sum_logits[1]; + } + __syncthreads(); + + float eps_i = epsilon / (vocab_size - 1); + if (threadIdx.x == 0) { + // neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max) + float nll_loss = logf(s_sum_exp) - + static_cast(inputs[block_start + target_tid]) + + s_max_input; + nll_loss_outputs[blockIdx.x] = nll_loss; + float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit; + outputs[blockIdx.x] = + (1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss; + } +} + +template +__global__ void ls_cross_entropy_bw_kernel( + const float *__restrict__ grad_outputs, const T *__restrict__ inputs, + const int *__restrict__ targets, T *__restrict__ grad_inputs, + const int padding_idx, const float epsilon, const int vocab_size) { + /* step1: compute each thread's max_logit and sum_exp_logit, store in + * max_input, sum_exp_logit */ + const int block_start = blockIdx.x * vocab_size; + const int left_idx = block_start + threadIdx.x; + const int right_idx = (blockIdx.x + 1) * vocab_size; + float max_input[1] = {REDUCE_FLOAT_INF_NEG}; + float sum_logits[1] = {0.f}; + const float grad_out = static_cast(grad_outputs[0]); + int target_tid = targets[blockIdx.x]; + + if (target_tid == padding_idx) { + for (int i = left_idx; i < right_idx; i += blockDim.x) { + grad_inputs[i] = 0.f; + } + return; + } + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); + } + blockReduce(max_input); + __shared__ float s_max_input; + if (threadIdx.x == 0) { + s_max_input = max_input[0]; + } + __syncthreads(); + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + float logit = static_cast(inputs[i]) - s_max_input; + sum_logits[0] += expf(logit); + } + + blockReduce(sum_logits); + __shared__ float s_sum_exp; + if (threadIdx.x == 0) { + s_sum_exp = sum_logits[0]; + } + __syncthreads(); + + float eps_i = epsilon / (vocab_size - 1); + float nll_weight = 1.0 - epsilon - eps_i; + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + float prob = expf(static_cast(inputs[i]) - s_max_input) / s_sum_exp; + float grad = 0; + grad += (vocab_size * prob - 1) * eps_i; + grad += prob * nll_weight; + if ((i - block_start) == target_tid) { + grad -= nll_weight; + } + grad_inputs[i] = grad_out * grad; + } +} + +template +void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr, + float *outputs_ptr, float *nll_loss_ptr, + float *loss_buffer, const int padding_idx, + const float epsilon, const int batch_size, + const int seq_len, const int vocab_size, + cudaStream_t stream) { + int grid_dim = batch_size * seq_len; + float *nll_loss_buffer = loss_buffer + grid_dim; + ls_cross_entropy_fw_kernel<<>>( + inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx, + epsilon, vocab_size); + + int num_items = grid_dim; + void *d_temp_storage = NULL; + size_t temp_storage_bytes = 0; + CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, + loss_buffer, outputs_ptr, + num_items, stream)); + CHECK_GPU_ERROR( + g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes)); + CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, + loss_buffer, outputs_ptr, + num_items, stream)); + CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, + nll_loss_buffer, nll_loss_ptr, + num_items, stream)); + CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage)); +} + +template void launch_cross_entropy_fw( + const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr, + float *nll_loss_ptr, float *loss_buffer, const int padding_idx, + const float epsilon, const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); + +template void launch_cross_entropy_fw<__half>( + const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr, + float *nll_loss_ptr, float *loss_buffer, const int padding_idx, + const float epsilon, const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); + +template +void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr, + const int *targets_ptr, T *grad_inputs_ptr, + const int padding_idx, const float epsilon, + const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream) { + int grid_dim = batch_size * seq_len; + ls_cross_entropy_bw_kernel<<>>( + grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx, + epsilon, vocab_size); +} + +template void launch_cross_entropy_bw( + const float *grad_outputs_ptr, const float *inputs_ptr, + const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx, + const float epsilon, const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); + +template void launch_cross_entropy_bw<__half>( + const float *grad_outputs_ptr, const __half *inputs_ptr, + const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx, + const float epsilon, const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu b/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu new file mode 100644 index 0000000000000000000000000000000000000000..09f34763f9b292a05d1e7d7b2bddfd9b6367d781 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu @@ -0,0 +1,88 @@ +/* Copyright 2021 The LightSeq Team + Copyright Microsoft DeepSpeed + This file is adapted from Microsoft DeepSpeed + Licensed under the MIT License. +*/ +#include "cublas_wrappers.h" + +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const float *A, + const float *B, float *C, cublasGemmAlgo_t algo) { + cublasStatus_t status = + cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha, + (const void *)A, CUDA_R_32F, (transa == CUBLAS_OP_N) ? m : k, + (const void *)B, CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n, + (const void *)beta, C, CUDA_R_32F, m, CUDA_R_32F, algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, n, k, (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const __half *A, + const __half *B, __half *C, cublasGemmAlgo_t algo) { + cublasStatus_t status = cublasGemmEx( + handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A, + CUDA_R_16F, (transa == CUBLAS_OP_N) ? m : k, (const void *)B, CUDA_R_16F, + (transb == CUBLAS_OP_N) ? k : n, (const void *)beta, (void *)C, + CUDA_R_16F, m, CUDA_R_32F, algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, n, k, (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, + const float *alpha, const float *beta, + const float *A, const float *B, float *C, + cublasOperation_t op_A, cublasOperation_t op_B, + int stride_A, int stride_B, int stride_C, + int batch, cublasGemmAlgo_t algo) { + cublasStatus_t status = cublasGemmStridedBatchedEx( + handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_32F, + (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_32F, + (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_32F, m, stride_C, + batch, CUDA_R_32F, algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, " + "error: %d) \n", + batch, m, n, k, (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, + const float *alpha, const float *beta, + const __half *A, const __half *B, __half *C, + cublasOperation_t op_A, cublasOperation_t op_B, + int stride_A, int stride_B, int stride_C, + int batch, cublasGemmAlgo_t algo) { + cublasStatus_t status = cublasGemmStridedBatchedEx( + handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_16F, + (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_16F, + (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_16F, m, stride_C, + batch, CUDA_R_32F, algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, n, k, (int)status); + return EXIT_FAILURE; + } + + return 0; +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu new file mode 100644 index 0000000000000000000000000000000000000000..26efa2ad6f31632a4e7ceddd06745b067759bb43 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu @@ -0,0 +1,170 @@ +#include +#include + + +#include "cuda_util.h" + +/* GPU function guard */ +std::string _cudaGetErrorString(cudaError_t error) { + return cudaGetErrorString(error); +} + +std::string _cudaGetErrorString(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return "CUBLAS_UNKNOW"; +} + +template +void check_gpu_error(T result, char const *const func, const char *const file, + int const line) { + if (result) { + throw std::runtime_error(std::string("[CUDA][ERROR] ") + +file + "(" + + std::to_string(line) + + "): " + (_cudaGetErrorString(result)) + "\n"); + } +} + +template void check_gpu_error(cudaError_t result, + char const *const func, + const char *const file, + int const line); +template void check_gpu_error(cublasStatus_t result, + char const *const func, + const char *const file, + int const line); + +template +void print_vec(const T *outv, std::string outn, int num_output_ele) { + std::cout << outn << ": "; + std::vector hout(num_output_ele, (T)0); + cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(T), + cudaMemcpyDeviceToHost); + for (int i = 0; i < num_output_ele; i++) { + std::cout << hout[i] << ", "; + } + std::cout << std::endl; +} + +template <> +void print_vec<__half>(const __half *outv, std::string outn, + int num_output_ele) { + std::cout << outn << ": "; + std::vector<__half> hout(num_output_ele, (__half)0.f); + cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(__half), + cudaMemcpyDeviceToHost); + for (int i = 0; i < num_output_ele; i++) { + std::cout << __half2float(hout[i]) << ", "; + } + std::cout << std::endl; +} + +template void print_vec(const float *outv, std::string outn, + int num_output_ele); + +template void print_vec(const int *outv, std::string outn, + int num_output_ele); + +template void print_vec<__half>(const __half *outv, std::string outn, + int num_output_ele); + +template +T *cuda_malloc(size_t ele_num) { + size_t byte_size = ele_num * sizeof(T); + T *pdata = nullptr; + CHECK_GPU_ERROR(cudaMalloc((void **)&pdata, byte_size)); + return pdata; +} + +template float *cuda_malloc(size_t ele_num); + +template __half *cuda_malloc<__half>(size_t ele_num); + +template uint8_t *cuda_malloc(size_t ele_num); + +void cuda_free(void *pdata) { + if (pdata != nullptr) { + cudaFree(pdata); + } +} + +template +struct _isnan { + __device__ bool operator()(T a) const { return isnan(a); } +}; + +template <> +struct _isnan<__half> { + __device__ bool operator()(const __half a) const { return __hisnan(a); } +}; + +template +struct _isinf { + __device__ bool operator()(T a) const { return isinf(a); } +}; + +template <> +struct _isinf<__half> { + __device__ bool operator()(const __half a) const { return __hisinf(a); } +}; + +template +void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, + std::string file, int line, cudaStream_t stream) { + // check_nan_inf = 0 for checking nan + // check_nan_inf = 1 for checking inf + bool res = false; + std::string msg = file + "(" + std::to_string(line) + "): "; + if (check_nan_inf) { + msg += "nan."; + res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr, + data_ptr + dsize, _isnan(), false, + thrust::logical_or()); + } else { + msg += "inf."; + res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr, + data_ptr + dsize, _isinf(), false, + thrust::logical_or()); + } + if (res) { + throw std::runtime_error(msg); + } + std::cout << msg << " [check pass]." << std::endl; +} + +template void check_nan_inf(const float *data_ptr, int dsize, + bool check_nan_inf, std::string file, + int line, cudaStream_t stream); + +template void check_nan_inf<__half>(const __half *data_ptr, int dsize, + bool check_nan_inf, std::string file, + int line, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..a39a6dae0f7fb6968e6ee65fde8db4bbc5d61ab0 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu @@ -0,0 +1,1002 @@ +#include +#include + +#include "kernels.h" + +#include + + +namespace cg = cooperative_groups; + +curandStatePhilox4_32_10_t *curandstate; + +/** + * @brief element-wise activation function on device, like Relu, Gelu + * + * @tparam enum class ActivationType, kRelu, kGelu + * @tparam input type + * @param any shape of float and __half2 + * @return same shape and type with input + */ +template +__forceinline__ __device__ T activation_kernel(T x); + +template <> +__device__ float activation_kernel(float x) { + float cdf = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +template <> +__device__ __half2 +activation_kernel(__half2 val) { + __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); + + tmp.x = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + return __hmul2(val, __float22half2_rn(tmp)); +} + +template <> +__device__ float activation_kernel(float x) { + return fmaxf(x, 0); +} + +template <> +__device__ __half2 +activation_kernel(__half2 x) { + return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), + fmaxf(0.f, __half2float(x.y))); +} + +/** + * @brief element-wise activation backward function on device + * + * @tparam enum class ActivationType + * @tparam input type + * @param any shape of float and __half2 + * @return same shape of input + */ +template +__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); + +template <> +__device__ float activation_bwd_kernel(float grad, + float x) { + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return grad * (dg1 + dg2 + dg3); +} + +template <> +__device__ __half activation_bwd_kernel( + __half grad, __half x_half) { + float x = __half2float(x_half); + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return grad * __float2half(dg1 + dg2 + dg3); +} + +template <> +__device__ float activation_bwd_kernel(float grad, + float x) { + return x > 0.f ? grad : 0.f; +} + +template <> +__device__ __half +activation_bwd_kernel(__half grad, __half x) { + const __half half_zero = __float2half(0.f); + return x > half_zero ? grad : half_zero; +} + +template <> +__device__ __half2 activation_bwd_kernel( + __half2 grad2, __half2 x_half2) { + const __half half_zero = __float2half(0.f); + return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, + x_half2.y > half_zero ? grad2.y : half_zero); +} + +/** + * @brief init curand states in global memory + * + * @thread grid_dim * block*dim to suuport any size of states + * @param state persistant curand states + * @param seed seed to init states + * @return void + */ +__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, + int seed) { + /* Each thread gets same seed, a different sequence + number, no offset */ + int id = threadIdx.x + blockIdx.x * blockDim.x; + curand_init(seed, id, 0, &state[id]); +} + +void launch_curand_init(int total_count, int dim, cudaStream_t stream) { + cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); + int grid_dim = total_count >> 9; + curand_init_kernel<<>>( + curandstate, std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); +} + +/** + * @brief element-wise dropout, store dropped position in mask, it's not + * in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param out any size of float and __half + * @param in same with out + * @param mask uint8 type, same size with out + * @param seed seed to curand + * @return void + */ +__global__ void ls_dropout_kernel(const int total_count, const float ratio, + float *__restrict__ out, + const float *__restrict__ in, + uint8_t *__restrict__ mask, const int seed) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + + float4 input4 = data4[i]; + float4 res4; + res4.x = input4.x * scale * m[0]; + res4.y = input4.y * scale * m[1]; + res4.z = input4.z * scale * m[2]; + res4.w = input4.w * scale * m[3]; + out4[i] = res4; +} + +__global__ void ls_dropout_kernel(const int total_count, const float ratio, + __half *__restrict__ out, + const __half *__restrict__ in, + uint8_t *__restrict__ mask, const int seed) { + const float scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = (uint8_t)(rand.x > ratio); + m[5] = (uint8_t)(rand.y > ratio); + m[6] = (uint8_t)(rand.z > ratio); + m[7] = (uint8_t)(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = *m8; + + float4 val_float4 = vals_float4[i]; + float4 out_float4; + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); + __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); + __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); + __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); + out_half2[0] = __hmul2(val_half2[0], scale_mask_1); + out_half2[1] = __hmul2(val_half2[1], scale_mask_2); + out_half2[2] = __hmul2(val_half2[2], scale_mask_3); + out_half2[3] = __hmul2(val_half2[3], scale_mask_4); + outs_float4[i] = out_float4; +} + +/** + * @brief element-wise dropout backward with dropout mask, it's + * not in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param in any size of float and __half + * @param mask uint8 type, same size with in + * @return void + */ +__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, + float *out, const float *in, + const uint8_t *__restrict__ mask) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *in4 = reinterpret_cast(in); + const uint32_t *mask4 = reinterpret_cast(mask); + + uint32_t *m4 = reinterpret_cast(m); + m4[0] = mask4[i]; + + float4 input4 = in4[i]; + float4 res4; + res4.x = input4.x * scale * static_cast(m[0]); + res4.y = input4.y * scale * static_cast(m[1]); + res4.z = input4.z * scale * static_cast(m[2]); + res4.w = input4.w * scale * static_cast(m[3]); + out4[i] = res4; +} + +__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, + __half *out, const __half *in, + const uint8_t *__restrict__ mask) { + const __half scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + float4 *out4 = reinterpret_cast(out); + const float4 *vals_float4 = reinterpret_cast(in); + const uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + uint64_t *m8 = reinterpret_cast(m); + m8[0] = mask8[i]; + + float4 val_float4 = vals_float4[i]; + float4 out_float4; + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + __half2 scale_mask_1 = + __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); + __half2 scale_mask_2 = + __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); + __half2 scale_mask_3 = + __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); + __half2 scale_mask_4 = + __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); + out_half2[0] = __hmul2(val_half2[0], scale_mask_1); + out_half2[1] = __hmul2(val_half2[1], scale_mask_2); + out_half2[2] = __hmul2(val_half2[2], scale_mask_3); + out_half2[3] = __hmul2(val_half2[3], scale_mask_4); + out4[i] = out_float4; +} + +template <> +void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, + int total_count, float ratio, cudaStream_t stream, + bool backward) { + int grid_dim = total_count >> 12; + if (!backward) { + ls_dropout_kernel<<>>( + total_count, ratio, out, vals, mask, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); + } else { + ls_dropout_bwd_kernel<<>>(total_count, ratio, + out, vals, mask); + } +} + +template <> +void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, + int total_count, float ratio, + cudaStream_t stream, bool backward) { + int grid_dim = total_count >> 13; + if (!backward) { + ls_dropout_kernel<<>>( + total_count, ratio, out, vals, mask, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); + } else { + ls_dropout_bwd_kernel<<>>(total_count, ratio, + out, vals, mask); + } +} + +/** + * @brief fused bias, dropout, and residual at the end of Attention and FFN, + * store dropped position in mask, it's not in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param out [batch_size, seq_len, hidden_size], float and __half + * @param in [batch_size, seq_len, hidden_size], float and __half + * @param mask [batch_size, seq_len, hidden_size], uint8 type + * @param bias [hidden_size], ffn bias + * @param residual [batch_size, seq_len, hidden_size], float and __half + * @param seed seed to curand + * @param hidden_size hidden size + * @return void + */ +__global__ void ls_dropout_res_bias_kernel( + const int total_count, const float ratio, float *__restrict__ out, + const float *__restrict__ in, uint8_t *__restrict__ mask, + const float *__restrict__ bias, const float *__restrict__ residual, + const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + const float4 *residual4 = reinterpret_cast(residual); + const float4 *bias4 = reinterpret_cast(bias); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = static_cast(rand.x > ratio); + m[1] = static_cast(rand.y > ratio); + m[2] = static_cast(rand.z > ratio); + m[3] = static_cast(rand.w > ratio); + + int bias_i = i % (hidden_size >> 2); + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + const float4 input4 = data4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + const float4 res4 = residual4[i]; + float4 output4; + + output4.x = (input4.x + b4.x) * scale * m[0] + res4.x; + output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; + output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; + output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; + + out4[i] = output4; +} + +__global__ void ls_dropout_res_bias_kernel( + const int total_count, const float ratio, __half *__restrict__ out, + const __half *__restrict__ in, uint8_t *__restrict__ mask, + const __half *__restrict__ bias, const __half *__restrict__ residual, + const int seed, const int hidden_size) { + const __half scale = 1. / (1. - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + const float4 *residual4 = reinterpret_cast(residual); + const float4 *bias4 = reinterpret_cast(bias); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = static_cast(rand.x > ratio); + m[1] = static_cast(rand.y > ratio); + m[2] = static_cast(rand.z > ratio); + m[3] = static_cast(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = static_cast(rand.x > ratio); + m[5] = static_cast(rand.y > ratio); + m[6] = static_cast(rand.z > ratio); + m[7] = static_cast(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = m8[0]; + + int bias_i = i % (hidden_size >> 3); + float4 val_float4 = vals_float4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + const float4 res4 = residual4[i]; + float4 out_float4; + + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + const __half2 *b_half2 = reinterpret_cast(&b4); + const __half2 *res_half2 = reinterpret_cast(&res4); + __half2 scale_mask_1 = + __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); + __half2 scale_mask_2 = + __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); + __half2 scale_mask_3 = + __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); + __half2 scale_mask_4 = + __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); + out_half2[0] = + __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); + out_half2[1] = + __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); + out_half2[2] = + __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); + out_half2[3] = + __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); + outs_float4[i] = out_float4; +} + +template <> +void launch_ls_dropout_res_bias(float *out, const float *vals, + uint8_t *mask, const float *bias, + const float *residual, int total_count, + int dim, float ratio, + cudaStream_t stream) { + int grid_dim = total_count >> 12; + ls_dropout_res_bias_kernel<<>>( + total_count, ratio, out, vals, mask, bias, residual, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, + uint8_t *mask, const __half *bias, + const __half *residual, int total_count, + int dim, float ratio, + cudaStream_t stream) { + int grid_dim = total_count >> 13; + ls_dropout_res_bias_kernel<<>>( + total_count, ratio, out, vals, mask, bias, residual, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +/** + * @brief fused bias and dropout backward at the end of Attention and FFN + * + * @thread + * gridDim.x = hidden_size / 8 + * blockDim.x = 8 + * blockDim.y = 1024 / 8 = 128 + * + * @param row_size batch_size * seq_len + * @param ratio dropout ratio + * @param in_grad [batch_size, seq_len, hidden_size], input grad + * @param bias_grad [hidden_size], bias grad + * @param out_grad [batch_size, seq_len, hidden_size], output grad + * @param mask [batch_size, seq_len, hidden_size], dropout mask + * @param hidden_size + * @return void + */ +__global__ void ls_dropout_bias_bwd_kernel( + const int row_size, const float ratio, float *__restrict__ in_grad, + float *__restrict__ bias_grad, const float *__restrict__ out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + // every block generate 8 bias result + __shared__ float tile[8][129]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); + int stride = hidden_size * 128; + float local_sum = 0; + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + for (int r = threadIdx.y; r < row_size; r += 128) { + float val = out_grad[idx]; + val *= scale * static_cast(mask[idx]); + local_sum += val; + in_grad[idx] = val; + idx += stride; + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + + float sum = 0; + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 7; + int y = tid & (127); + if (y < 32) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += tile[x][y + i * 32]; + } + } + __syncthreads(); + + for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); + + if (y == 0) tile[0][x] = sum; + __syncthreads(); + + if (threadIdx.x < 8) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); + bias_grad[pos] = tile[0][threadIdx.x]; + } +} + +__global__ void ls_dropout_bias_bwd_kernel( + const int row_size, const float ratio, __half *__restrict__ in_grad, + __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); + __shared__ __half2 tile[8][129]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); + const __half2 *out_grad2 = reinterpret_cast(out_grad); + __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); + int stride = hidden_size * 128; + __half2 local_sum = __float2half2_rn(0.f); + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + for (int r = threadIdx.y; r < row_size; r += 128) { + __half2 val = out_grad2[idx]; + __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); + val *= scale * m2; + local_sum += val; + in_grad2[idx] = val; + idx += stride; + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + + __half2 sum = __float2half2_rn(0.f); + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 7; + int y = tid & (127); + if (y < 32) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += tile[x][y + i * 32]; + } + } + __syncthreads(); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (y == 0) tile[0][x] = sum; + __syncthreads(); + + if (threadIdx.x < 8) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); + bias_grad2[pos] = tile[0][threadIdx.x]; + } +} + +template +void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream) { + dim3 grid_dim((dim - 1) / 8 + 1); + dim3 block_dim(8, 128); + ls_dropout_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); +} + +template <> +void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, + const __half *out_grad, const uint8_t *mask, + int row_size, int dim, float ratio, + cudaStream_t stream) { + dim >>= 1; + dim3 grid_dim((dim - 1) / 8 + 1); + dim3 block_dim(8, 128); + ls_dropout_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); +} + +template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, + const float *out_grad, + const uint8_t *mask, int row_size, + int dim, float ratio, + cudaStream_t stream); + +/** + * @brief fused bias, activation, and dropout at the end of first ffn + * + * @thread + * gridDim.x = hidden_size / 8 + * blockDim.x = 8 + * blockDim.y = 1024 / 8 = 128 + * + * @tparam act_type activation function, like kRelu, kGelu + * @param total_count total elements + * @param ratio drop ratio + * @param out [batch_size, seq_len, hidden_size], float and __half + * @param in [batch_size, seq_len, hidden_size], float and __half + * @param mask [batch_size, seq_len, hidden_size], uint8 type + * @param bias [hidden_size], ffn bias + * @param seed seed to curand + * @param hidden_size + * @return void + */ +template +__global__ void ls_dropout_act_bias_kernel( + const int total_count, const float ratio, float *__restrict__ out, + const float *__restrict__ in, uint8_t *__restrict__ mask, + const float *__restrict__ bias, const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + const float4 *bias4 = reinterpret_cast(bias); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + int bias_i = i % (hidden_size >> 2); + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + const float4 input4 = data4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + float4 output4; + + output4.x = + activation_kernel(input4.x + b4.x) * scale * m[0]; + output4.y = + activation_kernel(input4.y + b4.y) * scale * m[1]; + output4.z = + activation_kernel(input4.z + b4.z) * scale * m[2]; + output4.w = + activation_kernel(input4.w + b4.w) * scale * m[3]; + + out4[i] = output4; +} + +template +__global__ void ls_dropout_act_bias_kernel( + const int total_count, const float ratio, __half *__restrict__ out, + const __half *__restrict__ in, uint8_t *__restrict__ mask, + const __half *__restrict__ bias, const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + const float4 *bias4 = reinterpret_cast(bias); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = (uint8_t)(rand.x > ratio); + m[5] = (uint8_t)(rand.y > ratio); + m[6] = (uint8_t)(rand.z > ratio); + m[7] = (uint8_t)(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = *m8; + + int bias_i = i % (hidden_size >> 3); + float4 val_float4 = vals_float4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + float4 out_float4; + + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + const __half2 *b_half2 = reinterpret_cast(&b4); + + __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); + __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); + __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); + __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); + out_half2[0] = __hmul2( + activation_kernel(__hadd2(val_half2[0], b_half2[0])), + scale_mask_1); + out_half2[1] = __hmul2( + activation_kernel(__hadd2(val_half2[1], b_half2[1])), + scale_mask_2); + out_half2[2] = __hmul2( + activation_kernel(__hadd2(val_half2[2], b_half2[2])), + scale_mask_3); + out_half2[3] = __hmul2( + activation_kernel(__hadd2(val_half2[3], b_half2[3])), + scale_mask_4); + outs_float4[i] = out_float4; +} + +template <> +void launch_ls_dropout_act_bias( + float *out, const float *vals, uint8_t *mask, const float *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 10; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + __half *out, const __half *vals, uint8_t *mask, const __half *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 11; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + float *out, const float *vals, uint8_t *mask, const float *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 10; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + __half *out, const __half *vals, uint8_t *mask, const __half *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 11; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +/** + * @brief fused bias, activation, and dropout backward + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @tparam act_type kRelu + * @param row_size batch_size * seq_len + * @param ratio dropout ratio + * @param in_grad [batch_size, seq_len, hidden_size], input grad + * @param bias_grad [hidden_size], bias grad + * @param out_grad [batch_size, seq_len, hidden_size], output grad + * @param mask [batch_size, seq_len, hidden_size], dropout mask + * @param hidden_size + * @return void + */ +template +__global__ void ls_dropout_act_bias_bwd_kernel( + const int row_size, const float ratio, T *in_grad, + T *__restrict__ bias_grad, const T *__restrict__ input, + const T *__restrict__ bias, const T *out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + + int stride = hidden_size * WARP_SIZE; + float local_sum = 0; + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + if (col_idx < hidden_size) { + for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { + float val = out_grad[idx]; + float in = input[idx]; + float b = bias[idx % hidden_size]; + val = activation_bwd_kernel( + val * scale * static_cast(mask[idx]), in + b); + local_sum += val; + in_grad[idx] = val; + idx += stride; + } + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + float sum = tile[threadIdx.y][threadIdx.x]; + __syncthreads(); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; + __syncthreads(); + + if (threadIdx.y == 0) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + bias_grad[pos] = tile[0][threadIdx.x]; + } +} + +// @brief fused bias, activation, and dropout backward +// It is deprecated for precision reason. Keep it for future optimization. +// +// template +// __global__ void ls_dropout_act_bias_bwd_kernel( +// const int row_size, const float ratio, __half * in_grad, +// __half *__restrict__ bias_grad, const __half *__restrict__ input, const +// __half *__restrict__ bias, const __half * out_grad, const uint8_t +// *__restrict__ mask, const int hidden_size) { +// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); +// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; + +// cg::thread_block b = cg::this_thread_block(); +// cg::thread_block_tile g = cg::tiled_partition(b); + +// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); +// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); +// const __half2 *out_grad2 = reinterpret_cast(out_grad); +// const __half2 *input2 = reinterpret_cast(input); +// const __half2 *bias2 = reinterpret_cast(bias); + +// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + +// int stride = hidden_size * WARP_SIZE; +// __half2 local_sum = __float2half2_rn(0.f); + +// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); +// if (col_idx < hidden_size) { +// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { +// __half2 val = out_grad2[idx]; +// __half2 in2 = input2[idx]; +// __half2 b2 = bias2[idx % hidden_size ]; +// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); +// val = activation_bwd_kernel(val * scale +// * +// m2, +// in2+b2); +// local_sum += val; +// in_grad2[idx] = val; +// idx += stride; +// } +// } + +// tile[threadIdx.x][threadIdx.y] = local_sum; +// __syncthreads(); +// __half2 sum = tile[threadIdx.y][threadIdx.x]; +// __syncthreads(); + +// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + +// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; +// __syncthreads(); + +// if (threadIdx.y == 0) { +// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); +// bias_grad2[pos] = tile[0][threadIdx.x]; +// } +// } + +template +void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, + const T *bias, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream) { + dim3 grid_dim((dim - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + ls_dropout_act_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); +} + +// template <> +// void launch_ls_dropout_act_bias_bwd( +// __half *in_grad, __half *bias_grad,const __half *input, const __half +// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int +// dim, float ratio, cudaStream_t stream) { +// dim >>= 1; +// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); +// dim3 block_dim(WARP_SIZE, WARP_SIZE); +// ls_dropout_act_bias_bwd_kernel +// <<>>(row_size, ratio, in_grad, +// bias_grad, +// input, bias,out_grad, mask, dim); +// } + +template void launch_ls_dropout_act_bias_bwd( + float *in_grad, float *bias_grad, const float *input, const float *bias, + const float *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, + const __half *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + float *in_grad, float *bias_grad, const float *input, const float *bias, + const float *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, + const __half *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..bc90c54c0a004a3a847968a0d62c7ad8a999dcb3 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu @@ -0,0 +1,232 @@ +#include + +#include "kernels.h" + +namespace cg = cooperative_groups; + +/** +@brief: fuse_transpose_bias +Calculate the sum of elements in each column of the matrix. + +@thread +gridDim.x = ceil(cols / WARP_SIZE) +blockDim.x = WARP_SIZE +blockDim.y = WARP_SIZE + +@param +inp: [rows, cols] +out: [cols] +rows: the number of rows in the matrix +cols: the number of cols in the matrix +*/ +template +__global__ void column_sum_reduce(const T *__restrict__ inp, + T *__restrict__ out, int rows, int cols) { + __shared__ float tile[WARP_SIZE][WARP_SIZE]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + int y_stride = cols * WARP_SIZE; + float localSum = 0; + + // Loop across matrix row + // TODO: optimize to log complexity + if (idx < cols) { + int offset = flat_2dim(threadIdx.y, idx, cols); + for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { + localSum += (float)inp[offset]; + offset += y_stride; + } + } + + // The sum of a row in tile is equal to the sum of a col in original matrix + tile[threadIdx.x][threadIdx.y] = localSum; + + __syncthreads(); + + // Sum the shared buffer. + // The change of threadIdx.x is continuous + float sum = tile[threadIdx.y][threadIdx.x]; + + __syncthreads(); + + // Calculate the sum of a row in tile + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) { + int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); + if (pos < cols) out[pos] = sum; + } +} + +// [r, c] -> [c] +template <> +void launch_fuse_transpose_bias_kernel(const float *inp, float *out, + int rows, int cols, + cudaStream_t stream) { + dim3 grid_dim((cols - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + + column_sum_reduce + <<>>(inp, out, rows, cols); +} + +template <> +void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, + int rows, int cols, + cudaStream_t stream) { + dim3 grid_dim((cols - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + + column_sum_reduce<__half> + <<>>(inp, out, rows, cols); +} + +/** +@brief: fused_add2 +Add two matrix inp1 and inp2 to out. + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = min(hidden_dim, MAX_THREADS) + +@param +inp1: [batch_size, seq_len, hidden_dim] +inp2: [batch_size, seq_len, hidden_dim] +out: [batch_size, seq_len, hidden_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +*/ +template +__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, + int hidden_dim); + +template <> +__global__ void fused_add2_kernel(float *out, const float *inp1, + const float *inp2, int hidden_dim) { + int row_id = blockIdx.x; + int offset = flat_2dim(row_id, 0, hidden_dim); + + const float4 *inp1_4 = reinterpret_cast(inp1); + const float4 *inp2_4 = reinterpret_cast(inp2); + float4 *out_4 = reinterpret_cast(out); + float4 vinp1; + float4 vinp2; + float4 val; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinp1 = inp1_4[offset + i]; + vinp2 = inp2_4[offset + i]; + val.x = vinp1.x + vinp2.x; + val.y = vinp1.y + vinp2.y; + val.z = vinp1.z + vinp2.z; + val.w = vinp1.w + vinp2.w; + out_4[offset + i] = val; + } +} + +template <> +__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, + const __half *inp2, int hidden_dim) { + int row_id = blockIdx.x; + int offset = flat_2dim(row_id, 0, hidden_dim); + + const float4 *inp1_4 = reinterpret_cast(inp1); + const float4 *inp2_4 = reinterpret_cast(inp2); + float4 *out_4 = reinterpret_cast(out); + float4 vinp1; + float4 vinp2; + float4 val; + __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); + __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); + __half2 *h2_val = reinterpret_cast<__half2 *>(&val); + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinp1 = inp1_4[offset + i]; + vinp2 = inp2_4[offset + i]; + h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); + h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); + h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); + h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); + out_4[offset + i] = val; + } +} + +//[b, s, h] -> [b, s, h] +template <> +void launch_fused_add2(float *out, const float *inp1, const float *inp2, + int batch_size, int seq_len, int hidden_dim, + cudaStream_t &stream) { + hidden_dim >>= 2; + + dim3 grid_dim(batch_size * seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + fused_add2_kernel<<>>(out, inp1, inp2, + hidden_dim); +} + +template <> +void launch_fused_add2<__half>(__half *out, const __half *inp1, + const __half *inp2, int batch_size, int seq_len, + int hidden_dim, cudaStream_t &stream) { + hidden_dim >>= 3; + + dim3 grid_dim(batch_size * seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + fused_add2_kernel<<>>(out, inp1, inp2, + hidden_dim); +} + +template +__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, + int sz0, int sz2, int sz1_1, int sz1_2) { + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); + if (idx >= nele) { + return; + } + float4 *dst_ptr = (float4 *)output + idx; + int idx2 = idx % sz2; + idx = idx / sz2; + int idx1 = idx % (sz1_1 + sz1_2); + int idx0 = idx / (sz1_1 + sz1_2); + float4 *src_ptr = nullptr; + int sz1 = 0; + if (idx1 < sz1_1) { + sz1 = sz1_1; + src_ptr = (float4 *)inp1; + } else { + idx1 -= sz1_1; + sz1 = sz1_2; + src_ptr = (float4 *)inp2; + } + src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); + dst_ptr[0] = src_ptr[0]; +} + +template <> +void launch_concat3_dim1(const float *inp1, const float *inp2, + float *output, int sz0, int sz2, int sz1_1, + int sz1_2, cudaStream_t stream) { + sz2 >>= 2; + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_concat3_dim1<<>>( + inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); +} + +template <> +void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, + __half *output, int sz0, int sz2, int sz1_1, + int sz1_2, cudaStream_t stream) { + sz2 >>= 3; + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_concat3_dim1<<>>( + inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h b/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..38103c1734c86a67812e42a6382e7b5af0ca9b00 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h @@ -0,0 +1,312 @@ +/* Copyright 2021 The LightSeq Team + Copyright Tencent/TurboTransformers + This block_reduce_n is adapted from Tencent/TurboTransformers +*/ +#pragma once +#include +#include +#include + +enum class ReduceType { kMax = 0, kSum }; +const unsigned int WARP_REDUCE_MASK = 0xffffffff; +const float REDUCE_FLOAT_INF_NEG = -100000000.f; +const float REDUCE_FLOAT_INF_POS = 100000000.f; +const unsigned int WARP_REDUCE_SIZE = 32; + +template +__forceinline__ __device__ T warpReduceSum(T val) { + for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1) + val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE); + return val; +} + +/* Calculate the sum of all elements in a block */ +template +__forceinline__ __device__ T blockReduceSum(T val) { + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) shared[wid] = val; + __syncthreads(); + + val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f; + val = warpReduceSum(val); + return val; +} + +template +__inline__ __device__ void blockReduce(float *pval); + +// use template to make code more concise +template +__inline__ __device__ void warpReduce(float *pval); + +// static +template <> +__inline__ __device__ void warpReduce(float *pval) { + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32)); + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32)); + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32)); + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32)); + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32)); +} + +template <> +__inline__ __device__ void warpReduce(float *pval) { + float val0_tmp, val1_tmp; +#define WarpReduceMaxOneStep(a, b) \ + val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \ + val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ + *(pval) = max(val0_tmp, *(pval)); \ + *(pval + 1) = max(val1_tmp, *(pval + 1)); + + WarpReduceMaxOneStep(16, 32); + WarpReduceMaxOneStep(8, 32); + WarpReduceMaxOneStep(4, 32); + WarpReduceMaxOneStep(2, 32); + WarpReduceMaxOneStep(1, 32); +#undef WarpReduceMaxOneStep +} + +template <> +__inline__ __device__ void warpReduce(float *pval) { + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32); + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32); + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32); + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32); + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32); +} + +/* + * Unorll for loop for warpreduce to + * imporve instruction issue efficiency + * ElemX means there are X numbers to be summed + */ + +template <> +__inline__ __device__ void warpReduce(float *pval) { + float val0_tmp, val1_tmp; +#define WarpReduceSumOneStep(a, b) \ + val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \ + val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ + *(pval + 0) += val0_tmp; \ + *(pval + 1) += val1_tmp + + WarpReduceSumOneStep(16, 32); + WarpReduceSumOneStep(8, 32); + WarpReduceSumOneStep(4, 32); + WarpReduceSumOneStep(2, 32); + WarpReduceSumOneStep(1, 32); + +#undef WarpReduceSumOneStep +} + +template <> +__inline__ __device__ void warpReduce(float *pval) { + float val0_tmp, val1_tmp, val2_tmp, val3_tmp; +#define WarpReduceSumOneStep(a, b) \ + val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \ + val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ + val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \ + val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \ + *(pval + 0) += val0_tmp; \ + *(pval + 1) += val1_tmp; \ + *(pval + 2) += val2_tmp; \ + *(pval + 3) += val3_tmp + + WarpReduceSumOneStep(16, 32); + WarpReduceSumOneStep(8, 32); + WarpReduceSumOneStep(4, 32); + WarpReduceSumOneStep(2, 32); + WarpReduceSumOneStep(1, 32); +#undef WarpReduceSumOneStep +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 1; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = 0.f; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 2; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = 0.f; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 4; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = 0.f; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 1; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = REDUCE_FLOAT_INF_NEG; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 1; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = REDUCE_FLOAT_INF_NEG; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 1; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = REDUCE_FLOAT_INF_NEG; + } + } + warpReduce(pval); +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/context.h b/colossalai/kernel/cuda_native/csrc/kernels/include/context.h new file mode 100644 index 0000000000000000000000000000000000000000..f7d75f38cc2b568db74c935ef26cc14afce312ef --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/context.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +#include +#include + +#include "cuda_util.h" + +class Context { + public: + Context() : _stream(nullptr) { + CHECK_GPU_ERROR(cublasCreate(&_cublasHandle)); + } + + virtual ~Context() {} + + static Context &Instance() { + static Context _ctx; + return _ctx; + } + + void set_stream(cudaStream_t stream) { + _stream = stream; + CHECK_GPU_ERROR(cublasSetStream(_cublasHandle, _stream)); + } + + cudaStream_t get_stream() { return _stream; } + + cublasHandle_t get_cublashandle() { return _cublasHandle; } + + private: + cudaStream_t _stream; + cublasHandle_t _cublasHandle; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h new file mode 100644 index 0000000000000000000000000000000000000000..f4e9befc6588563e04c889da6460bd50ffa5aa56 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include + +#include + +#include "cuda_util.h" + +template +class CrossEntropyLayer { + public: + CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens); + + virtual ~CrossEntropyLayer(); + + void Forward(const T *inputs_ptr, const int *targets_ptr, float *outputs_ptr, + float *nll_loss_ptr); + + void Backward(const float *grad_outputs_ptr, const T *inputs_ptr, + const int *targets_ptr, T *grad_inputs_ptr); + + void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size); + + private: + void allocate_mem_buffer() { + // allocate local gpu memory + _loss_buffer = cuda_malloc(_max_batch_tokens * 2); + } + + void free_mem_buffer() { + // free local gpu memory + cuda_free(_loss_buffer); + } + + const int _padding_idx; + const float _epsilon; + const int _max_batch_tokens; + + size_t _batch_size; + size_t _seq_len; + size_t _vocab_size; + + float *_loss_buffer; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h new file mode 100644 index 0000000000000000000000000000000000000000..90255152b2c80412b3cf6e3e73a4de804a91634f --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h @@ -0,0 +1,41 @@ +/* Copyright 2021 The LightSeq Team + Copyright Microsoft DeepSpeed + This file is adapted from Microsoft DeepSpeed + Licensed under the MIT License. +*/ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const float *A, + const float *B, float *C, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); + +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const __half *A, + const __half *B, __half *C, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); + +int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, + const float *alpha, const float *beta, + const float *A, const float *B, float *C, + cublasOperation_t op_A, cublasOperation_t op_B, + int stride_A, int stride_B, int stride_C, + int batch, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); + +int cublas_strided_batched_gemm( + cublasHandle_t handle, int m, int n, int k, const float *alpha, + const float *beta, const __half *A, const __half *B, __half *C, + cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, int stride_B, + int stride_C, int batch, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h new file mode 100644 index 0000000000000000000000000000000000000000..1595257be0f5599f33c4fbcebbaa3047d0c4535f --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +template +void check_gpu_error(T result, char const *const func, const char *const file, + int const line); + +#define CHECK_GPU_ERROR(val) check_gpu_error((val), #val, __FILE__, __LINE__) + +template +void print_vec(const T *outv, std::string outn, int num_output_ele); + +template +T *cuda_malloc(size_t ele_num); + +void cuda_free(void *pdata); + +template +void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, + std::string file, int line, cudaStream_t stream); + +#define CHECK_NAN_INF(ptr, size, stream) \ + check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \ + check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream)) diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h new file mode 100644 index 0000000000000000000000000000000000000000..563a7fe284a305adb65965c365cc600878e7085c --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h @@ -0,0 +1,96 @@ +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +template +class Dropout { + public: + struct Config { + float ratio; + bool training; + + Config(float r) : ratio(r), training(true) {} + float RATIO() const { return training ? ratio : 0.0; } + }; + + Dropout(const Config &config, size_t max_ele_num) + : _config(config), _mask(nullptr) { + _mask = cuda_malloc(max_ele_num); + } + + virtual ~Dropout() { cuda_free(_mask); } + + // after attention softmax + void dropout(T *output, const T *input, int count, cudaStream_t stream, + bool bwd = false) { + launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, + bwd); + } + + void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { + launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), + stream, true); + } + + // transformer layer's postprocessing dropout, after attn or ffn module, + // before residual add. + void bias_dropout_residual(T *output, const T *input, const T *residual, + const T *bias, int rows, int cols, + cudaStream_t stream) { + launch_ls_dropout_res_bias(output, input, _mask, bias, residual, + rows * cols, cols, _config.RATIO(), stream); + } + + void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, + int rows, int cols, cudaStream_t stream) { + launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, + _config.RATIO(), stream); + } + + // dropout inside ffn. + void bias_act_dropout(T *output, const T *input, const T *bias, int rows, + int cols, std::string activation_fn, + cudaStream_t stream) { + if (activation_fn == "relu") { + launch_ls_dropout_act_bias( + output, input, _mask, bias, rows * cols, cols, _config.RATIO(), + stream); + } else if (activation_fn == "gelu") { + launch_ls_dropout_act_bias( + output, input, _mask, bias, rows * cols, cols, _config.RATIO(), + stream); + } else { + throw std::runtime_error("not supported activation: " + activation_fn); + } + } + + void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, + const T *bias, int rows, int cols, + std::string activation_fn, cudaStream_t stream) { + if (activation_fn == "relu") { + launch_ls_dropout_act_bias_bwd( + d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, + _config.RATIO(), stream); + } else if (activation_fn == "gelu") { + launch_ls_dropout_act_bias_bwd( + d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, + _config.RATIO(), stream); + } else { + throw std::runtime_error("not supported activation: " + activation_fn); + } + } + + bool HasDropout() const { return _config.RATIO() > 0.0; } + + void SetTrainingMode(bool training) { _config.training = training; } + + private: + uint8_t *_mask; + Config _config; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h new file mode 100644 index 0000000000000000000000000000000000000000..8186da1eed5f996d3cc4b204298ca802f0ac61b4 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h @@ -0,0 +1,69 @@ +#pragma once + +/* Copyright 2021 The LightSeq Team + Copyright Microsoft DeepSpeed + This file is adapted from Microsoft DeepSpeed + Licensed under the MIT License. +*/ +#include +#include +#include + +#include + +#include "cublas_wrappers.h" +#include "kernels.h" + +template +class FeedForward { + public: + struct Config { + int outputSize; + int inputSize; + std::array gemm_algos; + Config(int outputs, int inputs) + : outputSize(outputs), + inputSize(inputs), + gemm_algos(std::array({99, 99, 99})) {} + }; + + FeedForward(Config config) : config_(config) {} + + ~FeedForward() {} + + void Forward(int bsz, const T *input_ptr, const T *weights, T *out, + cublasHandle_t &_cublasHandle) { + float alpha = T(1.); + float beta = T(0.); + + cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize, + bsz, config_.inputSize, &alpha, &beta, weights, input_ptr, + out, cublasGemmAlgo_t(config_.gemm_algos[0])); + } + void Backward(int bsz, const T *out_grad, const T *input_ptr, + const T *weights, T *weights_grad, T *bias_grad, + cublasHandle_t &_cublasHandle, cudaStream_t &stream, + T *inp_grad_out = nullptr, T *out_grad_trans_out = nullptr, + bool compute_bias = true) { + float alpha = (T)1.0, beta = (T)0.0; + cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize, + config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad, + weights_grad, cublasGemmAlgo_t(config_.gemm_algos[1])); + + cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize, + bsz, config_.outputSize, &alpha, &beta, weights, out_grad, + inp_grad_out, cublasGemmAlgo_t(config_.gemm_algos[2])); + if (compute_bias) { + launch_fuse_transpose_bias_kernel(out_grad, bias_grad, bsz, + config_.outputSize, stream); + } + } + + void reset_size(int outputSize, int inputSize) { + config_.outputSize = outputSize; + config_.inputSize = inputSize; + } + + private: + Config config_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..fbb9c5465c24e98165f72b212854f51211e4e8a1 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h @@ -0,0 +1,270 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#define MAX_THREADS 1024 +#define WARP_SIZE 32 + +enum class ActivationType { kRelu, kGelu }; + +void launch_curand_init(int total_count, int dim, cudaStream_t stream); + +template +void launch_layer_norm(T *ln_res, T *vars, T *means, const T *inp, + const T *scale, const T *bias, int batch_size, + int hidden_dim, cudaStream_t stream); + +template +void launch_ln_bw(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, const T *gamma, + const T *betta, const T *vars, const T *means, int batch, + int hidden_dim, cudaStream_t stream[2]); + +template +void launch_attn_softmax(T *vals, const T *attn_mask, int batch_size, int heads, + int from_len, int to_len, bool mask_future, + cudaStream_t stream); + +template +void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, + int softmax_len, cudaStream_t stream); + +// [b, s, h] -> [b, nh, s, ad] +template +void launch_transform_0213(T *output, const T *vals, int batch_size, + int seq_length, int hidden_dim, int nhead, + cudaStream_t stream); + +// [b, s, 3, h] -> [3, b, nh, s, ad] +template +void launch_bias_add_transform_20314(T *output, const T *input, const T *bias, + int dim_0, int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream); + +// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] +template +void launch_transform4d_0213(T *output, const T *vals, int batch_size, + int seq_len, int hidden_dim, int nhead, + int trans_count, cudaStream_t stream); + +template +void launch_ls_dropout(T *out, const T *vals, uint8_t *mask, int total_count, + float ratio, cudaStream_t stream, bool backward = false); + +template +void launch_ls_dropout_res_bias(T *out, const T *vals, uint8_t *mask, + const T *bias, const T *residual, + int total_count, int dim, float ratio, + cudaStream_t stream); + +template +void launch_ls_dropout_act_bias(T *out, const T *vals, uint8_t *mask, + const T *bias, int total_count, int dim, + float ratio, cudaStream_t stream); + +template +void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template +void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, + const T *bias, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template +void launch_fuse_transpose_bias_kernel(const T *inp, T *out, int rows, int cols, + cudaStream_t stream); + +void launch_param_update(const float *input, __half *output, int size, + cudaStream_t stream); + +template +void launch_concat3_dim1(const T *inp1, const T *inp2, T *output, int sz0, + int sz2, int sz1_1, int sz1_2, cudaStream_t stream); + +template +void launch_fused_add2(T *out, const T *inp1, const T *inp2, int batch_size, + int seq_len, int hidden_size, cudaStream_t &stream); + +template +void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr, + float *outputs_ptr, float *nll_loss_ptr, + float *loss_buffer, const int padding_idx, + const float epsilon, const int batch_size, + const int seq_len, const int vocab_size, + cudaStream_t stream); + +template +void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr, + const int *targets_ptr, T *grad_inputs_ptr, + const int padding_idx, const float epsilon, + const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); + +template +void launch_lookup_scale_pos_dropout( + T *output, const int *input, const T *embeddings, const T *pos_embeddings, + uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim, + int padding_idx, float dropout_ratio, int step, cudaStream_t &stream); + +template +void launch_d_lookup_scale_pos_dropout( + T *grad_embeddings, const T *grad_output, const int *input, + const uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim, + int vocab_size, int padding_idx, float dropout_ratio, cudaStream_t &stream); + +/* Convert 2-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int flat_2dim(int id1, int id2, int dim2) { + return id1 * dim2 + id2; +} + +/* Convert 3-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3, + int dim2, int dim3) { + return id1 * dim2 * dim3 + id2 * dim3 + id3; +} + +/* Convert 4-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int +flat_4dim(int id1, int id2, int id3, int id4, int dim2, int dim3, int dim4) { + // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4; + int res = id4; + + int ld = dim4; + res += id3 * ld; + + ld *= dim3; + res += id2 * ld; + + ld *= dim2; + res += id1 * ld; + + return res; +} + +/* Convert 5-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int flat_5dim(int id1, int id2, int id3, + int id4, int id5, int dim2, + int dim3, int dim4, + int dim5) { + // return id1*(dim2*dim3*dim4*dim5) + id2*(dim3*dim4*dim5) + id3*(dim4*dim5) + + // id4*dim5 + dim5; + int res = id5; + + int ld = dim5; + res += id4 * ld; + + ld *= dim4; + res += id3 * ld; + + ld *= dim3; + res += id2 * ld; + + ld *= dim2; + res += id1 * ld; + + return res; +} + +/* Convert 6-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3, + int id4, int id5, int id6, + int dim2, int dim3, int dim4, + int dim5, int dim6) { + // return id1*(dim2*dim3*dim4*dim5*dim6) + id2*(dim3*dim4*dim5*dim6) + + // id3*(dim4*dim5*dim6) + id4*(dim5*dim6) + id5*dim6 + id6; + int res = id6; + + int ld = dim6; + res += id5 * ld; + + ld *= dim5; + res += id4 * ld; + + ld *= dim4; + res += id3 * ld; + + ld *= dim3; + res += id2 * ld; + + ld *= dim2; + res += id1 * ld; + + return res; +} + +/* Convert vector index to 6-dim tensor index */ +__forceinline__ __host__ __device__ void +decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5, + int *id0, int *id1, int *id2, int *id3, int *id4, int *id5) { + *id5 = src % dim5; + src /= dim5; + + *id4 = src % dim4; + src /= dim4; + + *id3 = src % dim3; + src /= dim3; + + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +/* Convert vector index to 5-dim tensor index */ +__forceinline__ __host__ __device__ void +decompose_5dim(int src, int dim1, int dim2, int dim3, int dim4, int *id0, + int *id1, int *id2, int *id3, int *id4) { + *id4 = src % dim4; + src /= dim4; + + *id3 = src % dim3; + src /= dim3; + + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +/* Convert vector index to 4-dim tensor index */ +__forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1, + int dim2, int dim3, + int *id0, int *id1, + int *id2, int *id3) { + *id3 = src % dim3; + src /= dim3; + + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +/* Convert vector index to 3-dim tensor index */ +__forceinline__ __host__ __device__ void +decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) { + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +/* Convert vector index to 2-dim tensor index */ +__forceinline__ __host__ __device__ void decompose_2dim(int src, int dim1, + int *id0, int *id1) { + *id1 = src % dim1; + *id0 = src / dim1; +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh b/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh new file mode 100644 index 0000000000000000000000000000000000000000..4f65e7b54ba19e9520e19d969bebbe4e5d43c266 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh @@ -0,0 +1,12 @@ +// copied from https://github.com/dmlc/dgl/pull/2758 +#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_ +#define DGL_ARRAY_CUDA_DGL_CUB_CUH_ + +#define CUB_NS_PREFIX namespace ls { +#define CUB_NS_POSTFIX } +#include "cub/cub.cuh" +#include "cub/util_allocator.cuh" +#undef CUB_NS_POSTFIX +#undef CUB_NS_PREFIX + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h new file mode 100644 index 0000000000000000000000000000000000000000..ded5c0fdcbeee6f02d29544a577021f2407976d7 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +using namespace std; + +template class Normalize_Layer { +public: + struct Config { + uint32_t hidden_dim; + bool use_mean; + Config(uint32_t hidden_dim, bool use_mean = false) + : hidden_dim(hidden_dim), use_mean(use_mean) {} + }; + + Normalize_Layer(Config config, size_t max_rows) + : config_(config), vars_(nullptr), means_(nullptr) { + vars_ = cuda_malloc(max_rows); + if (config_.use_mean) { + means_ = cuda_malloc(max_rows); + } + } + + ~Normalize_Layer() { + cuda_free(vars_); + cuda_free(means_); + } + + void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, + int batch_size, cudaStream_t stream) { + launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, + config_.hidden_dim, stream); + } + + /* + residual_grad, inp_or_out, betta should be treated carefully. + inp_or_out = input if use_mean else output + residual_grad, betta can be nullptr. + residual_grad will be added to dinp if it is not nullptr + which is useful in transformer layer when pre-ln + betta are only used to compute xhat, + (use_mean == false) ^ (betta == nullptr) should be true + */ + void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, const T *gamma, + const T *betta, int batch_size, cudaStream_t stream[2]) { + launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, + inp_or_out, gamma, betta, vars_, means_, batch_size, + config_.hidden_dim, stream); + } + + inline bool use_mean() const { return config_.use_mean; } + +private: + Config config_; + T *vars_; + T *means_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..ec447ad84c54614839ef14838e23898191e97129 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +using namespace std; + +template +class Softmax { + public: + struct Config { + size_t nhead; + Config(size_t nhead) : nhead(nhead) {} + }; + + Softmax(Config config) : config_(config) {} + + ~Softmax() {} + + void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, + int to_len, cudaStream_t &stream, bool mask_future = true) { + launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, + to_len, mask_future, stream); + } + + void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, + int to_len, cudaStream_t stream) { + launch_attn_softmax_bw(out_grad, soft_out, + batch_size * config_.nhead * from_len, to_len, + stream); + } + + void reset_size(size_t nhead) { config_.nhead = nhead; } + + private: + Config config_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h b/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..d386650e82356e261ceca179a4681813230762a2 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h @@ -0,0 +1,100 @@ +/* Copyright 2021 The LightSeq Team + Copyright Microsoft DeepSpeed + This file is adapted from Microsoft DeepSpeed + Licensed under the MIT License. +*/ +#pragma once + +#include +#include +#include + +#include + +#include "cublas_wrappers.h" + +template +class StridedBatchGemm { + public: + struct Config { + int m; + int n; + int k; + float alpha; + float beta; + cublasOperation_t op_A; + cublasOperation_t op_B; + std::array gemm_algos; + + Config(float param_alpha, float param_beta, cublasOperation_t opA, + cublasOperation_t opB) + : alpha(param_alpha), + beta(param_beta), + op_A(opA), + op_B(opB), + gemm_algos(std::array({99, 99, 99})) {} + void SetConfig(int mm, int nn, int kk) { + m = mm; + n = nn; + k = kk; + } + }; + + StridedBatchGemm(const Config &config) : _config(config) {} + + virtual ~StridedBatchGemm() {} + + void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b, + cublasHandle_t handle) { + int stride_a = _config.m * _config.k; + int stride_b = _config.n * _config.k; + int stride_c = _config.m * _config.n; + + cublas_strided_batched_gemm( + handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta, + _buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a, + stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0])); + } + + void Backward(int bsz, const T *d_output, const T *_buffer_a, + const T *_buffer_b, cublasHandle_t handle, + T *inpGradA = nullptr, T *inpGradB = nullptr) { + int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m); + int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k); + + int stride_a = mb * _config.n; + int stride_b = _config.n * kb; + int stride_c = _config.m * _config.k; + + // B need to transpose. + cublasOperation_t op_b = + (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); + + // Calculate d_A. + cublas_strided_batched_gemm( + handle, mb, kb, _config.n, &_config.alpha, &_config.beta, + (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output), + (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA, + CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz, + cublasGemmAlgo_t(_config.gemm_algos[1])); + + // A need to transpose. + cublasOperation_t op_a = + (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); + + stride_a = _config.m * _config.k; + stride_b = _config.m * _config.n; + stride_c = _config.n * _config.k; + + // Calculate d_B. + cublas_strided_batched_gemm( + handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta, + _buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b, + stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2])); + } + + inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); } + + private: + Config _config; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..3e61d4e35832cb9b2f90ebc9b64d8398f6d6b9e6 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu @@ -0,0 +1,1169 @@ +#include "block_reduce.h" +#include "kernels.h" +#include + +namespace cg = cooperative_groups; +const float LN_EPSILON = 1e-8f; +#define TILE_DIM 32 + +template __forceinline__ __device__ T add_eps(T x) { + return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); +} + +/** +@brief: ker_layer_norm +Standard layer normalization. +It will not only output the layer norm result, + but also outputs variance. + may also output means, depends on whether + the means argument is nullptr + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = hidden_size + +@param +ln_res: [batch_size* seq_len, hidden_size], ln result. +vars: [batch_size* seq_len], variance per token +means: [batch_size* seq_len], means per token, can be nullput +inp: [batch_size * seq_len, hidden_size], ln input. +scale: [hidden_size], ln scale +bias: [hidden_size], ln bias +*/ +template +__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, + const T *scale, const T *bias, int hidden_size) { + // step 0. compute local sum + float l_sum = 0; + float l_square_sum = 0; + const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 val = inp_f4[idx]; + l_sum += val.x + val.y + val.z + val.w; + l_square_sum += + val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_size) * 4.f; + float reduce_val[2] = {l_sum, l_square_sum}; + blockReduce(reduce_val); + __shared__ float s_mean, s_var; + if (threadIdx.x == 0) { + s_mean = reduce_val[0] / mean_dim; + if (means != nullptr) { + means[blockIdx.x] = s_mean; + } + s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; + vars[blockIdx.x] = s_var; + s_var = rsqrtf(s_var); + } + __syncthreads(); + + // step 2. layer norm result + float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 vscale = __ldg((const float4 *)scale + idx); + float4 vbias = __ldg((const float4 *)bias + idx); + float4 val = inp_f4[idx]; + val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; + val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; + val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; + val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; + output_f4[idx] = val; + } +} + +template <> +__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, + __half *means, const __half *inp, + const __half *scale, const __half *bias, + int hidden_size) { + // step 0. compute local sum + float l_sum = 0; + float l_square_sum = 0; + const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 val_f4 = inp_f4[idx]; + __half2 *val_h2 = (__half2 *)(&val_f4); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 val_f2 = __half22float2(val_h2[i]); + l_sum += val_f2.x + val_f2.y; + l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; + } + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_size) * 8.f; + float reduce_val[2] = {l_sum, l_square_sum}; + blockReduce(reduce_val); + __shared__ float s_mean, s_var; + if (threadIdx.x == 0) { + s_mean = reduce_val[0] / mean_dim; + if (means != nullptr) { + means[blockIdx.x] = s_mean; + } + s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; + vars[blockIdx.x] = s_var; + s_var = rsqrtf(s_var); + } + __syncthreads(); + + // step 2. layer norm result + float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + // load scale, bias, input + float4 scale_f4 = __ldg((const float4 *)scale + idx); + __half2 *scale_h2 = (__half2 *)(&scale_f4); + float4 bias_f4 = __ldg((const float4 *)bias + idx); + __half2 *bias_h2 = (__half2 *)(&bias_f4); + float4 val_f4 = inp_f4[idx]; + __half2 *val_h2 = (__half2 *)(&val_f4); + +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 scale_f2 = __half22float2(scale_h2[i]); + float2 bias_f2 = __half22float2(bias_h2[i]); + float2 val_f2 = __half22float2(val_h2[i]); + val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; + val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; + val_h2[i] = __float22half2_rn(val_f2); + } + output_f4[idx] = val_f4; + } +} + +// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, +// __half *means, const __half *inp, +// const __half *scale, const __half +// *bias, int hidden_size) { +// // step 0. compute local sum +// float l_sum = 0; +// float l_square_sum = 0; +// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * +// 2) { +// float4 val_f4 = inp_f4[idx]; +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; +// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x +// * val_f2_1.x + val_f2_1.y * val_f2_1.y; +// } +// } + +// // step 1. compute reduce sum +// float mean_dim = float(hidden_size) * 8.f * 2; +// float reduce_val[2] = {l_sum, l_square_sum}; +// blockReduce(reduce_val); +// __shared__ float s_mean, s_var; +// if (threadIdx.x == 0) { +// s_mean = reduce_val[0] / mean_dim; +// if (means != nullptr) { +// means[blockIdx.x] = s_mean; +// } +// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; +// vars[blockIdx.x] = s_var; +// s_var = rsqrtf(s_var); +// } +// __syncthreads(); + +// // step 2. layer norm result +// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * +// 2) { +// // load scale, bias, input +// float4 scale_f4 = __ldg((const float4 *)scale + idx); +// __half2 *scale_h2 = (__half2 *)(&scale_f4); +// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); +// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); +// float4 bias_f4 = __ldg((const float4 *)bias + idx); +// __half2 *bias_h2 = (__half2 *)(&bias_f4); +// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); +// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); +// float4 val_f4 = inp_f4[idx]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 scale_f2 = __half22float2(scale_h2[i]); +// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); +// float2 bias_f2 = __half22float2(bias_h2[i]); +// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; +// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; +// val_h2[i] = __float22half2_rn(val_f2); +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + +// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y +// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); +// } +// output_f4[idx] = val_f4; +// output_f4[idx+1] = val_f4_1; +// } +// } + +// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, +// __half *means, const __half *inp, +// const __half *scale, const __half +// *bias, int hidden_size) { +// // step 0. compute local sum +// float l_sum = 0; +// float l_square_sum = 0; +// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * +// 4) { +// float4 val_f4 = inp_f4[idx]; +// float4 val_f4_1 = inp_f4[idx+1]; +// float4 val_f4_2 = inp_f4[idx+2]; +// float4 val_f4_3 = inp_f4[idx+3]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); +// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// float2 val_f2_2 = __half22float2(val_h2_2[i]); +// float2 val_f2_3 = __half22float2(val_h2_3[i]); +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + +// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * +// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x +// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + +// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + +// val_f2_3.y * val_f2_3.y; +// } +// } + +// // step 1. compute reduce sum +// float mean_dim = float(hidden_size) * 8.f * 4; +// float reduce_val[2] = {l_sum, l_square_sum}; +// blockReduce(reduce_val); +// __shared__ float s_mean, s_var; +// if (threadIdx.x == 0) { +// s_mean = reduce_val[0] / mean_dim; +// if (means != nullptr) { +// means[blockIdx.x] = s_mean; +// } +// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; +// vars[blockIdx.x] = s_var; +// s_var = rsqrtf(s_var); +// } +// __syncthreads(); + +// // step 2. layer norm result +// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * +// 4) { +// // load scale, bias, input +// float4 scale_f4 = __ldg((const float4 *)scale + idx); +// __half2 *scale_h2 = (__half2 *)(&scale_f4); +// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); +// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); +// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); +// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); +// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); +// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); +// float4 bias_f4 = __ldg((const float4 *)bias + idx); +// __half2 *bias_h2 = (__half2 *)(&bias_f4); +// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); +// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); +// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); +// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); +// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); +// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); +// float4 val_f4 = inp_f4[idx]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// float4 val_f4_2 = inp_f4[idx+2]; +// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); +// float4 val_f4_3 = inp_f4[idx+3]; +// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 scale_f2 = __half22float2(scale_h2[i]); +// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); +// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); +// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); +// float2 bias_f2 = __half22float2(bias_h2[i]); +// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); +// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); +// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// float2 val_f2_2 = __half22float2(val_h2_2[i]); +// float2 val_f2_3 = __half22float2(val_h2_3[i]); +// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; +// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + +// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y +// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * +// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var +// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * +// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) +// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = +// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); +// val_h2_2[i] = __float22half2_rn(val_f2_2); +// val_h2_3[i] = __float22half2_rn(val_f2_3); +// } +// output_f4[idx] = val_f4; +// output_f4[idx+1] = val_f4_1; +// output_f4[idx+2] = val_f4_2; +// output_f4[idx+3] = val_f4_3; +// } +// } + +template <> +void launch_layer_norm(float *ln_res, float *vars, float *means, + const float *inp, const float *scale, + const float *bias, int batch_size, int hidden_dim, + cudaStream_t stream) { + if (hidden_dim % 4 != 0) { + throw std::runtime_error("violate hidden_dim % 4 = 0"); + } + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + dim3 grid_dim(batch_size); + dim3 block_dim(nthread); + + ker_layer_norm<<>>( + ln_res, vars, means, inp, scale, bias, hidden_dim); +} + +template <> +void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, + const __half *inp, const __half *scale, + const __half *bias, int batch_size, + int hidden_dim, cudaStream_t stream) { + if (hidden_dim % 8 != 0) { + throw std::runtime_error("violate hidden_dim % 8 = 0"); + } + hidden_dim >>= 3; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + dim3 grid_dim(batch_size); + dim3 block_dim(nthread); + + ker_layer_norm<__half><<>>( + ln_res, vars, means, inp, scale, bias, hidden_dim); + // if (hidden_dim % 8 != 0) { + // throw std::runtime_error("violate hidden_dim % 8 = 0"); + // } + // hidden_dim >>= 3; + + // if (hidden_dim * 8 < 8192) { + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm<__half><<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { + // hidden_dim >>= 1; + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm_x2<<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { + // hidden_dim >>= 2; + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm_x4<<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else { + // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); + // } +} + +/** +@brief: ker_ln_bw_dgamma_dbetta +Layer norm backword kernel, compute the gradient of gamma and betta. +dbetta = sum(dout, dim=0) +dgamma = sum(xhat * dout, dim=0) +xhat = (input - mean) * rsqrt(var) or + (output - betta) / gamma + + +@thread +gridDim.x = hidden_size / 32 +blockDim.x = 32 +blockDim.y = 32 + +@param +gamma_grad: [hidden_size], gradient of gamma +betta_grad: [hidden_size], gradient of betta +out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr + ln input if means is not nullptr +gamma: [hidden_size], gamma of ln, + used to compute xhat, maybe nullptr +betta: [hidden_size], betta of ln, + used to compute xhat, maybe nullptr +vars: [batch_size * seq_len], variance of ln forward, + used to compute xhat, maybe nullptr +means: [batch_size * seq_len], mean of ln forward, + used to compute xhat, maybe nullptr +(gamma && betta) ^ (vars && means) should be true +*/ +template +__global__ void +ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, const T *out_grad, + const T *inp_or_out, const T *gamma, const T *betta, + const T *vars, const T *means, int rows, int width) { + __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + // Loop across inp height + float dbetta = 0; + float dgamma = 0; + float dout, val; + if (idx < width) { + if (means == nullptr) { + float vbetta = (float)betta[idx]; + float vgamma = (float)gamma[idx]; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + dout = (float)out_grad[offset]; + // inp_or_out is output + val = (float)inp_or_out[offset]; + dbetta += dout; + dgamma += ((val - vbetta) / add_eps(vgamma) * dout); + offset += y_stride; + } + } else { + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + dout = (float)out_grad[offset]; + // inp_or_out is input + val = (float)inp_or_out[offset]; + dbetta += dout; + dgamma += ((val - (float)means[r]) * + rsqrtf((float)vars[r] + LN_EPSILON) * dout); + offset += y_stride; + } + } + } + + // Sum the shared buffer. + betta_buffer[threadIdx.x][threadIdx.y] = dbetta; + gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; + __syncthreads(); + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + __syncthreads(); + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + if (threadIdx.x == 0 && idx < width) { + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +/** +@brief: ker_ln_bw_dinp +Layer norm backword kernel, compute the gradient of input. +dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) + * rsqrt(var) +xhat = (input - mean) * rsqrt(var) if mean is not nullptr + (output - betta) / gamma if mean is nullptr +dxhat = dout * gamma + + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = hidden_size + +@param +inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, + usually appear in pre-layer-norm for transformer layer, maybe nullptr +inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr + ln input if means is not nullptr +gamma: [hidden_size], gamma of ln, + used to compute xhat and dxhat +betta: [hidden_size], betta of ln, + used to compute xhat, maybe nullptr +vars: [batch_size * seq_len], variance of ln forward, + used to compute xhat and dinp +means: [batch_size * seq_len], mean of ln forward, + used to compute xhat, maybe nullptr +*/ +template +__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, + const T *gamma, const T *betta, const T *vars, + const T *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim + threadIdx.x; + float4 dxhat, xhat; + float var_rsqrt; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + dxhat = ((const float4 *)out_grad)[offset]; + float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; + dxhat.x *= vgamma.x; + dxhat.y *= vgamma.y; + dxhat.z *= vgamma.z; + dxhat.w *= vgamma.w; + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + xhat = ((const float4 *)inp_or_out)[offset]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[threadIdx.x]; + xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); + xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); + xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); + xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; + xhat.x = (xhat.x - fmean) * var_rsqrt; + xhat.y = (xhat.y - fmean) * var_rsqrt; + xhat.z = (xhat.z - fmean) * var_rsqrt; + xhat.w = (xhat.w - fmean) * var_rsqrt; + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + float reduce_val[2] = {0.f, 0.f}; + if (threadIdx.x < hidden_dim) { + reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; + reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + + dxhat.w * xhat.w; + } + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 4; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + dxhat.x += dresidual.x; + dxhat.y += dresidual.y; + dxhat.z += dresidual.z; + dxhat.w += dresidual.w; + } + ((float4 *)inp_grad)[offset] = dxhat; +} + +template <> +__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, + int hidden_dim) { + int offset = blockIdx.x * hidden_dim + threadIdx.x; + + float2 dxhat[4], xhat[4]; + float var_rsqrt; + float4 vtmp; + __half2 *tmp_h2; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[threadIdx.x]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vbetta = __half22float2(betta_h2[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; +} + +__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, const __half *gamma, + const __half *betta, const __half *vars, + const __half *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; + + float2 dxhat[4], xhat[4]; + float2 dxhat_1[4], xhat_1[4]; + float var_rsqrt; + float4 vtmp, vtmp_1; + __half2 *tmp_h2; + __half2 *tmp_h2_1; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + vtmp_1 = ((const float4 *)out_grad)[offset + 1]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; + float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); + __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vdout_1 = __half22float2(tmp_h2_1[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + dxhat_1[i].x = vdout_1.x * vgamma_1.x; + dxhat_1[i].y = vdout_1.y * vgamma_1.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; + float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); + __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vout_1 = __half22float2(tmp_h2_1[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vbetta = __half22float2(betta_h2[i]); + float2 vbetta_1 = __half22float2(betta_h2_1[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + float2 vinp_1 = __half22float2(tmp_h2_1[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8 * 2; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); + __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; + ((float4 *)inp_grad)[offset + 1] = vtmp_1; +} + +__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, const __half *gamma, + const __half *betta, const __half *vars, + const __half *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; + + float2 dxhat[4], xhat[4]; + float2 dxhat_1[4], xhat_1[4]; + float2 dxhat_2[4], xhat_2[4]; + float2 dxhat_3[4], xhat_3[4]; + float var_rsqrt; + float4 vtmp, vtmp_1, vtmp_2, vtmp_3; + __half2 *tmp_h2; + __half2 *tmp_h2_1; + __half2 *tmp_h2_2; + __half2 *tmp_h2_3; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + vtmp_1 = ((const float4 *)out_grad)[offset + 1]; + vtmp_2 = ((const float4 *)out_grad)[offset + 2]; + vtmp_3 = ((const float4 *)out_grad)[offset + 3]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); + tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); + tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; + float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; + float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; + float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); + __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); + __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); + __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vdout_1 = __half22float2(tmp_h2_1[i]); + float2 vdout_2 = __half22float2(tmp_h2_2[i]); + float2 vdout_3 = __half22float2(tmp_h2_3[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vgamma_2 = __half22float2(gamma_h2_2[i]); + float2 vgamma_3 = __half22float2(gamma_h2_3[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + dxhat_1[i].x = vdout_1.x * vgamma_1.x; + dxhat_1[i].y = vdout_1.y * vgamma_1.y; + dxhat_2[i].x = vdout_2.x * vgamma_2.x; + dxhat_2[i].y = vdout_2.y * vgamma_2.y; + dxhat_3[i].x = vdout_3.x * vgamma_3.x; + dxhat_3[i].y = vdout_3.y * vgamma_3.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + + dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + + dxhat_3[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; + vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; + vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; + float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; + float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; + float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); + __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); + __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); + __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vout_1 = __half22float2(tmp_h2_1[i]); + float2 vout_2 = __half22float2(tmp_h2_2[i]); + float2 vout_3 = __half22float2(tmp_h2_3[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vgamma_2 = __half22float2(gamma_h2_2[i]); + float2 vgamma_3 = __half22float2(gamma_h2_3[i]); + float2 vbetta = __half22float2(betta_h2[i]); + float2 vbetta_1 = __half22float2(betta_h2_1[i]); + float2 vbetta_2 = __half22float2(betta_h2_2[i]); + float2 vbetta_3 = __half22float2(betta_h2_3[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); + xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); + xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); + xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); + xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += + xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + float2 vinp_1 = __half22float2(tmp_h2_1[i]); + float2 vinp_2 = __half22float2(tmp_h2_2[i]); + float2 vinp_3 = __half22float2(tmp_h2_3[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; + xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; + xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; + xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; + xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += + xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8 * 4; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; + float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; + float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); + __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); + __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); + __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])); + tmp_h2_2[i].x = __float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_2[2 * i])); + tmp_h2_3[i].x = __float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_3[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + tmp_h2_2[i].y = __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + tmp_h2_3[i].y = __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_2[i].x = __float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_3[i].x = __float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_2[i].y = __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_3[i].y = __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; + ((float4 *)inp_grad)[offset + 1] = vtmp_1; + ((float4 *)inp_grad)[offset + 2] = vtmp_2; + ((float4 *)inp_grad)[offset + 3] = vtmp_3; +} + +/** +Layer norm backword, + compute the gradient of gamma, betta and input. +dbetta = sum(dout, dim=0) +xhat = (input - mean) * rsqrt(var) if mean is not nullptr + (output - betta) / gamma if mean is nullptr +dgamma = sum(xhat * dout, dim=0) +dxhat = dout * gamma +dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) + * rsqrt(var) + +residual_grad, means, betta can be nullptr. +residual_grad will be added to dinp if it is not nullptr + which is useful in transformer layer when pre-ln +means and betta are only used to compute xhat, + (means == nullptr) ^ (betta == nullptr) should be true +*/ +template <> +void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, + const float *out_grad, const float *residual_grad, + const float *inp_or_out, const float *gamma, + const float *betta, const float *vars, + const float *means, int batch, int hidden_dim, + cudaStream_t stream[2]) { + // compute grad of gamma and betta + dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + ker_ln_bw_dgamma_dbetta<<>>( + gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, + batch, hidden_dim); + + // compute grad of input + if (hidden_dim % 4 != 0 || hidden_dim > 4096) { + throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); + } + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, + hidden_dim); +} + +template <> +void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, + __half *inp_grad, const __half *out_grad, + const __half *residual_grad, const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, int batch, + int hidden_dim, cudaStream_t stream[2]) { + // compute grad of gamma and betta + dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + ker_ln_bw_dgamma_dbetta<__half><<>>( + gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, + batch, hidden_dim); + + // compute grad of input + if (hidden_dim % 8 != 0) { + throw std::runtime_error("hidden_dim % 8 != 0"); + } + hidden_dim >>= 3; + + if (hidden_dim * 8 <= 8192) { + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { + hidden_dim >>= 1; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp_x2<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp_x4<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else { + throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); + } +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..98af433fe3972db64f8599bf4d23597365b80f4b --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu @@ -0,0 +1,365 @@ +#include +#include + +#include +#include + +#include "block_reduce.h" +#include "kernels.h" + +namespace cg = cooperative_groups; +const float EPSILON = 1e-8f; + +/** +@brief: softmax_kernel +Softmax forward kernel for + enc-self-attn, dec-self-attn, encdec-attn + +@thread +gridDim.x = dynamic +gridDim.y = batch_size +gridDim.z = nhead +blockDim.x = from_len + +@param +inp: [batch_size, nhead, from_len, to_len], softmax input. +attn_mask: [batch_size, to_len], padding tokens are -inf, + non padding tokens are 0. + attn_mask!=nullptr for enc-self-attn and enc-dec-attn + attn_mask=nullptr and mask_future=ture for dec-self-attn training + attn_mask=nullptr and mask_future=false for dec-self-attn infer +*/ +template +__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, + int to_len, bool mask_future) { + int batch_id = blockIdx.y; + int head_id = blockIdx.z; + const int nhead = gridDim.z; + const int token_per_reduce = 1; + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + T mval[ele_per_thread]; + if (attn_mask) { + attn_mask += batch_id * to_len; + BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); + } + + inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); + for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; + token_id += gridDim.x * token_per_reduce) { + T inp_val[token_per_reduce][ele_per_thread]; + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, + REDUCE_FLOAT_INF_NEG); + } + + /* step 1. compute max */ + // thread local max + float val[token_per_reduce][ele_per_thread]; + float l_max[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_max[i] = REDUCE_FLOAT_INF_NEG; + for (int j = 0; j < ele_per_thread; j++) { + if (attn_mask) { + val[i][j] = (float)inp_val[i][j] + (float)mval[j]; + } else { + if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { + val[i][j] = REDUCE_FLOAT_INF_NEG; + } else { + val[i][j] = (float)inp_val[i][j]; + } + } + l_max[i] = fmaxf(l_max[i], val[i][j]); + } + } + // block reduce max + blockReduce(l_max); + // write shared + __shared__ float s_max[token_per_reduce]; + if (threadIdx.x == 0) { + for (int i = 0; i < token_per_reduce; i++) { + s_max[i] = l_max[i]; + } + } + __syncthreads(); + + /* step 2. compute sum */ + // thread local sum + float l_sum[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_sum[i] = 0.f; + for (int j = 0; j < ele_per_thread; j++) { + val[i][j] = __expf(val[i][j] - s_max[i]); + l_sum[i] += val[i][j]; + } + } + // block reduce sum + blockReduce(l_sum); + // write shared + __shared__ float s_sum[token_per_reduce]; + if (threadIdx.x == 0) { + for (int i = 0; i < token_per_reduce; i++) { + s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); + } + } + __syncthreads(); + + /* step 3. compute final result */ + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + for (int j = 0; j < ele_per_thread; j++) { + inp_val[i][j] = (T)(val[i][j] * s_sum[i]); + } + BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], + to_len); + } + } // blockIdx.x +} + +template +__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, + int to_len, bool mask_future) { + int batch_id = blockIdx.y; + int head_id = blockIdx.z; + const int nhead = gridDim.z; + const int token_per_reduce = 1; + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + T mval[ele_per_thread]; + if (attn_mask) { + attn_mask += batch_id * to_len; + BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); + } + + inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); + for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; + token_id += gridDim.x * token_per_reduce) { + T inp_val[token_per_reduce][ele_per_thread]; + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, + REDUCE_FLOAT_INF_NEG); + } + + /* step 1. compute max */ + // thread local max + float val[token_per_reduce][ele_per_thread]; + float l_max[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_max[i] = REDUCE_FLOAT_INF_NEG; + for (int j = 0; j < ele_per_thread; j++) { + if (attn_mask) { + val[i][j] = (float)inp_val[i][j] + (float)mval[j]; + } else { + if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { + val[i][j] = REDUCE_FLOAT_INF_NEG; + } else { + val[i][j] = (float)inp_val[i][j]; + } + } + l_max[i] = fmaxf(l_max[i], val[i][j]); + } + } + // warp reduce max + warpReduce(l_max); + + /* step 2. compute sum */ + // thread local sum + float l_sum[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_sum[i] = 0.f; + for (int j = 0; j < ele_per_thread; j++) { + val[i][j] = __expf(val[i][j] - l_max[i]); + l_sum[i] += val[i][j]; + } + } + // warp reduce sum + warpReduce(l_sum); + + /* step 3. compute final result */ + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); + for (int j = 0; j < ele_per_thread; j++) { + inp_val[i][j] = (T)(val[i][j] * l_sum[i]); + } + BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], + to_len); + } + } // blockIdx.x +} + +/* + attn_mask!=nullptr for enc-self-attn and enc-dec-attn + attn_mask=nullptr and mask_future=ture for dec-self-attn training + attn_mask=nullptr and mask_future=false for dec-self-attn infer +*/ +template <> +void launch_attn_softmax(float *inp, const float *attn_mask, + int batch_size, int nhead, int from_len, + int to_len, bool mask_future, + cudaStream_t stream) { + dim3 grid_dim(1, batch_size, nhead); + if (to_len <= 32) { + ker_attn_softmax_lt32<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 64) { + ker_attn_softmax_lt32<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 128) { + grid_dim.x = 16; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 256) { + grid_dim.x = 32; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 512) { + grid_dim.x = 64; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else { + throw std::runtime_error( + "Sequence length greater than 512 is currently not supported"); + } +} + +template <> +void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, + int batch_size, int nhead, int from_len, + int to_len, bool mask_future, + cudaStream_t stream) { + dim3 grid_dim(1, batch_size, nhead); + if (to_len <= 32) { + ker_attn_softmax_lt32<__half, 32, 1><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 64) { + ker_attn_softmax_lt32<__half, 32, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 128) { + grid_dim.x = 8; + ker_attn_softmax<__half, 64, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 256) { + grid_dim.x = 16; + ker_attn_softmax<__half, 128, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 512) { + grid_dim.x = 32; + ker_attn_softmax<__half, 256, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else { + throw std::runtime_error( + "Sequence length greater than 512 is currently not supported"); + } +} + +/** +@brief: ker_attn_softmax_bw +Softmax backward in self attention. + +@thread +gridDim.x = batch_size * nhead * seq_len / warps_per_block +blockDim.x = WARP_SIZE +blockDim.y = warps_per_block + +@param +grad: [batch_size, nhead, seq_len, seq_len], output grad. +output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. +*/ +template +__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { + int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; + int offset = batch_idx * softmax_length + threadIdx.x; + + grad += offset; + inp += offset; + + T grad_reg[ITERATIONS]; + T inp_reg[ITERATIONS]; + float sum = 0.0; + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) { + grad_reg[i] = grad[i * WARP_SIZE]; + inp_reg[i] = inp[i * WARP_SIZE]; + sum += (float)grad_reg[i] * (float)inp_reg[i]; + } + } + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) + grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); + } +} + +template +void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, + int softmax_len, cudaStream_t stream) { + const int warps_per_block = 4; + // rows = batch_size * nhead * from_len + dim3 grid_dim(rows / warps_per_block); + dim3 block_dim(WARP_SIZE, warps_per_block); + + if (softmax_len <= 32) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 64) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 128) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 256) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 384) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 512) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 768) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 1024) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 2048) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else + throw std::runtime_error( + std::string( + "Special sequence length found in softmax backward, seq_len: ") + + std::to_string(softmax_len)); +} + +template void launch_attn_softmax_bw<__half>(__half *out_grad, + const __half *soft_inp, int rows, + int softmax_len, + cudaStream_t stream); +template void launch_attn_softmax_bw(float *out_grad, + const float *soft_inp, int rows, + int softmax_len, + cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..d03084b22e126fe5facd8ff709ac94a81e511a7e --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu @@ -0,0 +1,312 @@ +#include +#include +#include + +#include "kernels.h" + +using namespace cub; + +/** +@brief: transform_0213 +Split the attention heads and reshape input +during backward progress of encoder self-attention + +@thread +gridDim.x = batch_size +gridDim.y = seq_len +blockDim.x = min(hidden_dim, MAX_THREADS) + +@param +input: [batch_size, seq_len, hidden_dim] +output: [batch_size, nhead, seq_len, head_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +nhead: number of attention heads +*/ + +template +__global__ void transform_0213(T *output, const T *input, int hidden_dim, + int head_dim); + +template <> +__global__ void transform_0213(float *output, const float *input, + int hidden_dim, int head_dim) { + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + int seq_len = gridDim.y; + int nhead = hidden_dim / head_dim; + + // [b, s, h] + int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); + // [b, nh, s, ad] + int trg_offset = + flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + float4 vinput4; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinput4 = input4[src_offset + i]; + + int head_id = i / head_dim; + int dim_id = i % head_dim; + int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); + res4[trg_offset + cur_trg_offset] = vinput4; + } +} + +template <> +__global__ void transform_0213<__half>(__half *output, const __half *input, + int hidden_dim, int head_dim) { + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + int seq_len = gridDim.y; + int nhead = hidden_dim / head_dim; + + // [b, s, h] + int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); + // [b, nh, s, ad] + int trg_offset = + flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + float4 vinput4; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinput4 = input4[src_offset + i]; + + int head_id = i / head_dim; + int dim_id = i % head_dim; + int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); + res4[trg_offset + cur_trg_offset] = vinput4; + } +} + +// [b, s, h] -> [b, nh, s, ad] +template <> +void launch_transform_0213(float *output, const float *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, cudaStream_t stream) { + hidden_dim >>= 2; + int head_dim = hidden_dim / nhead; + + dim3 grid_dim(batch_size, seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + transform_0213 + <<>>(output, input, hidden_dim, head_dim); +} + +template <> +void launch_transform_0213<__half>(__half *output, const __half *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, cudaStream_t stream) { + hidden_dim >>= 3; + int head_dim = hidden_dim / nhead; + + dim3 grid_dim(batch_size, seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + transform_0213<__half> + <<>>(output, input, hidden_dim, head_dim); +} + +/** +@brief: bias_add_transform_20314 +Add bias to input, transform from +[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] + +@thread +gridDim.x = dim_0 +gridDim.y = dim_1 +gridDim.z = dim_2 +blockDim.x = min(dim_3 * dim_4, MAX_THREADS) + +@param +input: [dim_0, dim_1, dim_2, dim_3, dim_4] +bias: [dim_2, dim_3, dim_4] +output: [dim_2, dim_0, dim_3, dim_1, dim_4] +*/ +template +__global__ void bias_add_transform_20314(T *output, const T *input, + const T *bias, int dim_3, int dim_4); + +template <> +__global__ void +bias_add_transform_20314(float *output, const float *input, + const float *bias, int dim_3, int dim_4) { + int id0 = blockIdx.x; + int id1 = blockIdx.y; + int id2 = blockIdx.z; + int dim_0 = gridDim.x; + int dim_1 = gridDim.y; + int dim_2 = gridDim.z; + int dim_34 = dim_3 * dim_4; + + int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); + int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); + int bias_offset = flat_2dim(id2, 0, dim_34); + + const float4 *qkv4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + float4 *res4 = reinterpret_cast(output); + float4 vqkv4; + float4 vbias4; + float4 vres4; + + for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { + vqkv4 = qkv4[src_offset + i]; + vbias4 = bias4[bias_offset + i]; + vres4.x = vqkv4.x + vbias4.x; + vres4.y = vqkv4.y + vbias4.y; + vres4.z = vqkv4.z + vbias4.z; + vres4.w = vqkv4.w + vbias4.w; + + int id3 = i / dim_4; + int id4 = i % dim_4; + int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); + res4[trg_offset + cur_trg_offset] = vres4; + } +} + +template <> +__global__ void +bias_add_transform_20314<__half>(__half *output, const __half *input, + const __half *bias, int dim_3, int dim_4) { + int id0 = blockIdx.x; + int id1 = blockIdx.y; + int id2 = blockIdx.z; + int dim_0 = gridDim.x; + int dim_1 = gridDim.y; + int dim_2 = gridDim.z; + int dim_34 = dim_3 * dim_4; + + int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); + int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); + int bias_offset = flat_2dim(id2, 0, dim_34); + + const float4 *qkv4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + float4 *res4 = reinterpret_cast(output); + float4 vqkv4; + float4 vbias4; + float4 vres4; + __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); + __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); + __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); + + for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { + vqkv4 = qkv4[src_offset + i]; + vbias4 = bias4[bias_offset + i]; + h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); + h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]); + h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); + h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); + + int id3 = i / dim_4; + int id4 = i % dim_4; + int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); + res4[trg_offset + cur_trg_offset] = vres4; + } +} + +// [b, s, 3, h] -> [3, b, nh, s, ad] +template <> +void launch_bias_add_transform_20314(float *output, const float *input, + const float *bias, int dim_0, + int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream) { + dim_4 >>= 2; + + dim3 grid_dim(dim_0, dim_1, dim_2); + dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); + + bias_add_transform_20314 + <<>>(output, input, bias, dim_3, dim_4); +} + +template <> +void launch_bias_add_transform_20314<__half>(__half *output, + const __half *input, + const __half *bias, int dim_0, + int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream) { + dim_4 >>= 3; + + dim3 grid_dim(dim_0, dim_1, dim_2); + dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); + + bias_add_transform_20314<__half> + <<>>(output, input, bias, dim_3, dim_4); +} + +/** +@brief: transform4d_0213 +Reshape the input matrix to merge the heads + +@thread +gridDim.x = (num_all + max_block_thread - 1) / max_block_thread +blockDim.x = max_block_thread + +@param +input: [trans_count, batch_size, nhead, seq_len, head_dim] +output: [batch_size, seq_len, trans_count, nhead, head_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +nhead: number of attention heads +trans_count: 1 or 3, the count of matrice need to be transformed +*/ +template +__global__ void transform4d_0213(T *output, const T *input, int batch_size, + int seq_len, int trans_count, int nhead, + int head_dim, int num_all) { + int offset = blockIdx.x * blockDim.x + threadIdx.x; + if (offset >= num_all) { + return; + } + int trans_id, batch_id, head_id, token_id, dim_id; + decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id, + &batch_id, &head_id, &token_id, &dim_id); + // [b, s, tc, nh, ad] + int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, + seq_len, trans_count, nhead, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + res4[trg_offset] = input4[offset]; +} + +// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] +template <> +void launch_transform4d_0213(float *output, const float *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, int trans_count, + cudaStream_t stream) { + hidden_dim >>= 2; + int head_dim = hidden_dim / nhead; + int num_all = batch_size * seq_len * trans_count * hidden_dim; + int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; + + transform4d_0213<<>>( + output, input, batch_size, seq_len, trans_count, nhead, head_dim, + num_all); +} + +template <> +void launch_transform4d_0213<__half>(__half *output, const __half *input, + int batch_size, int seq_len, + int hidden_dim, int nhead, int trans_count, + cudaStream_t stream) { + hidden_dim >>= 3; + int head_dim = hidden_dim / nhead; + int num_all = batch_size * seq_len * trans_count * hidden_dim; + int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; + + transform4d_0213<__half><<>>( + output, input, batch_size, seq_len, trans_count, nhead, head_dim, + num_all); +} diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4690277e63db0a49c23c9274e6553da1e6b04103 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp @@ -0,0 +1,141 @@ +/*This code from NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include + +#include +#include + +#include "compat.h" + +namespace { + +void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1, + int &n2) { + int idiff = input.ndimension() - normalized_shape.size(); + n2 = 1; + for (int i = 0; i < (int)normalized_shape.size(); ++i) { + assert(input.sizes()[i + idiff] == normalized_shape[i]); + n2 *= normalized_shape[i]; + } + n1 = 1; + for (int i = 0; i < idiff; ++i) { + n1 *= input.sizes()[i]; + } +} + +void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma, + at::Tensor beta) { + TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); + TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); +} + +void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1, + int &n2) { + int64_t normalized_ndim = normalized_shape.size(); + + if (normalized_ndim < 1) { + std::stringstream ss; + ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " + << "containing at least one element, but got normalized_shape=" + << normalized_shape; + throw std::runtime_error(ss.str()); + } + + auto input_shape = input.sizes(); + auto input_ndim = input.dim(); + + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim) + .equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + throw std::runtime_error(ss.str()); + } + + compute_n1_n2(input, normalized_shape, n1, n2); +} + +void check_args(at::Tensor input, at::IntArrayRef normalized_shape, + at::Tensor gamma, at::Tensor beta, int &n1, int &n2) { + check_args(input, normalized_shape, n1, n2); + check_args(normalized_shape, gamma, beta); +} +} // namespace + +void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar, + at::Tensor *input, int n1, int n2, + at::IntArrayRef normalized_shape, at::Tensor *gamma, + at::Tensor *beta, double epsilon); + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +std::vector layer_norm_affine(at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, at::Tensor beta, + double epsilon) { + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor output = + at::empty_like(input, gamma.options().dtype(gamma.scalar_type())); + at::Tensor mean = + at::empty({n1}, input.options().dtype(at::ScalarType::Float)); + at::Tensor invvar = at::empty_like(mean); + + cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape, + &gamma, &beta, epsilon); + + return {output, mean, invvar}; +} + +void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean, + at::Tensor *invvar, at::Tensor *input, int n1, + int n2, at::IntArrayRef normalized_shape, + at::Tensor *gamma, at::Tensor *beta, + double epsilon, at::Tensor *grad_input, + at::Tensor *grad_gamma, at::Tensor *grad_beta); + +std::vector layer_norm_gradient_affine( + at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input, + at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta, + double epsilon) { + CHECK_INPUT(dout); + CHECK_INPUT(mean); + CHECK_INPUT(invvar); + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor grad_input = at::empty_like(input); + at::Tensor grad_gamma = at::empty_like(gamma); + at::Tensor grad_beta = at::empty_like(beta); + + cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon, + &grad_input, &grad_gamma, &grad_beta); + + return {grad_input, grad_gamma, grad_beta}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); + m.def("backward_affine", &layer_norm_gradient_affine, + "LayerNorm backward (CUDA)"); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..ad7066bbd9df1c1582946092460734986cdc2d03 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu @@ -0,0 +1,683 @@ +/*This code from NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include +#include + +#include "ATen/ATen.h" +#include "ATen/AccumulateType.h" +#include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/DeviceUtils.cuh" +#include "type_shim.h" + +template +__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) { + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template +__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, + U& mu, U& sigma2, U& count) { + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA * mu + nB * muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template +__device__ void cuWelfordMuSigma2(const T* __restrict__ vals, const int n1, + const int n2, const int i1, U& mu, U& sigma2, + U* buf) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + U count = U(0); + mu = U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1 * n2; + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l + k]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + U muB = WARP_SHFL(mu, srcLaneB); + U countB = WARP_SHFL(count, srcLaneB); + U sigma2B = WARP_SHFL(sigma2, srcLaneB); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U* ubuf = (U*)buf; + U* ibuf = (U*)(ubuf + blockDim.y); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2 * wrt_y] = mu; + ubuf[2 * wrt_y + 1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U muB = ubuf[2 * threadIdx.y]; + U sigma2B = ubuf[2 * threadIdx.y + 1]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1] / U(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2 / U(n2), 0); + } + } +} + +template <> +__device__ void cuWelfordMuSigma2(const at::Half* __restrict__ vals, + const int n1, const int n2, const int i1, + float& mu, float& sigma2, float* buf) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + float count = 0.0f; + mu = float(0); + sigma2 = float(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const at::Half* lvals = vals + i1 * n2; + int l = 8 * thrx; + if ((((size_t)lvals) & 3) != 0) { + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. + for (; l + 7 < n2; l += 8 * numx) { + for (int k = 0; k < 8; k += 2) { + float2 curr = __half22float2(*((__half2*)(lvals + l + k))); + cuWelfordOnlineSum(curr.x, mu, sigma2, count); + cuWelfordOnlineSum(curr.y, mu, sigma2, count); + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + float muB = WARP_SHFL(mu, srcLaneB); + float countB = WARP_SHFL(count, srcLaneB); + float sigma2B = WARP_SHFL(sigma2, srcLaneB); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + float* ubuf = (float*)buf; + float* ibuf = (float*)(ubuf + blockDim.y); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2 * wrt_y] = mu; + ubuf[2 * wrt_y + 1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float muB = ubuf[2 * threadIdx.y]; + float sigma2B = ubuf[2 * threadIdx.y + 1]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1] / float(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2 / float(n2), 0); + } + } +} + +template +U rsqrt(U v) { + return U(1) / sqrt(v); +} +template <> +float rsqrt(float v) { + return rsqrtf(v); +} +template <> +double rsqrt(double v) { + return rsqrt(v); +} + +namespace { +// This is the un-specialized struct. Note that we prevent instantiation of +// this struct by putting an undefined symbol in the function body so it won't +// compile. +// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; + +template <> +struct SharedMemory { + __device__ float* getPointer() { + extern __shared__ float s_float[]; + return s_float; + } +}; + +} // namespace + +template +__global__ void cuApplyLayerNorm(V* __restrict__ output_vals, + U* __restrict__ mean, U* __restrict__ invvar, + const T* __restrict__ vals, const int n1, + const int n2, const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U* buf = shared.getPointer(); + U mu, sigma2; + cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf); + const T* lvals = vals + i1 * n2; + V* ovals = output_vals + i1 * n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && beta != NULL) { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + } + } else { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + ovals[i] = static_cast(c_invvar * (curr - mu)); + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + mean[i1] = mu; + invvar[i1] = c_invvar; + } + } +} + +template +__device__ void cuLoadWriteStridedInputs( + const int i1_block, const int thr_load_row_off, const int thr_load_col_off, + const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2, + const T* input, const V* dout, const int i1_end, const int n2, + const U* __restrict__ mean, const U* __restrict__ invvar) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = + curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } +} + +template +__device__ void cuLoadAddStridedInputs( + const int i1_block, const int thr_load_row_off, const int thr_load_col_off, + const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2, + const T* input, const V* dout, const int i1_end, const int n2, + const U* __restrict__ mean, const U* __restrict__ invvar) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += + curr_dout * (curr_input - curr_mean) * curr_invvar; + } + } + } +} + +template +__global__ void cuComputePartGradGammaBeta( + const V* __restrict__ dout, const T* __restrict__ input, const int n1, + const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, + U epsilon, U* part_grad_gamma, U* part_grad_beta) { + const int numsegs_n1 = + (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y; + const int i1_beg_plus_one = + (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; + const int row_stride = blockDim.x + 1; + const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1); + const int thr_load_row_off = + (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * + // blockDim.y + (blockDim.y - + // 1)*(blockDim.x/blockDim.y) elements + U* warp_buf1 = (U*)buf; + U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, + row_stride, warp_buf1, warp_buf2, input, dout, + i1_end, n2, mean, invvar); + for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; + i1_block += blockDim.y * blockDim.y) { + cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, + row_stride, warp_buf1, warp_buf2, input, dout, + i1_end, n2, mean, invvar); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k * blockDim.y; + int idx1 = row1 * row_stride + threadIdx.x; + acc1 += warp_buf1[idx1]; + acc2 += warp_buf2[idx1]; + } + warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; + warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y / 2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template +__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, + const U* part_grad_beta, + const int part_size, const int n1, + const int n2, V* grad_gamma, + V* grad_beta) { + // sum partial gradients for gamma and beta + SharedMemory shared; + U* buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U* part_grad_gamma_ptr = + part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U* part_grad_beta_ptr = + part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; + ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; + sum_beta += part_grad_beta_ptr[warp_offset * n2]; + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + buf[write_idx + nbsize3] = sum_beta; + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + sum_beta += buf[read_idx + nbsize3]; + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + grad_beta[i2] = sum_beta; + } + } +} + +template +__global__ void cuComputeGradInput(const V* __restrict__ dout, + const T* __restrict__ input, const int n1, + const int n2, const U* __restrict__ mean, + const U* __restrict__ invvar, U epsilon, + const V* gamma, T* grad_input) { + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + const U c_mean = mean[i1]; + const U c_invvar = invvar[i1]; + const T* k_input = input + i1 * n2; + const V* k_dout = dout + i1 * n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + sum_loss1 += c_loss * gamma[l + k]; + sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss * gamma[l]; + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } + } else { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + // intra-warp reductions + for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U* buf = shared.getPointer(); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[2 * wrt_i] = sum_loss1; + buf[2 * wrt_i + 1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + sum_loss1 += buf[2 * read_i]; + sum_loss2 += buf[2 * read_i + 1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + buf[2 * threadIdx.x] = sum_loss1; + buf[2 * threadIdx.x + 1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y != 0) { + sum_loss1 = buf[2 * threadIdx.x]; + sum_loss2 = buf[2 * threadIdx.x + 1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + i1 * n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * gamma[l]; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + } +} + +template +void HostApplyLayerNorm(V* output, U* mean, U* invvar, const T* input, int n1, + int n2, double epsilon, const V* gamma, const V* beta) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const dim3 threads(32, 4, 1); + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; + cuApplyLayerNorm<<>>( + output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); +} + +void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, + at::Tensor* input, int n1, int n2, +#ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, +#else + at::IntList normalized_shape, +#endif + at::Tensor* gamma, at::Tensor* beta, double epsilon) { + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", + HostApplyLayerNorm(output->DATA_PTR(), + mean->DATA_PTR(), invvar->DATA_PTR(), + input->DATA_PTR(), n1, n2, epsilon, + gamma != NULL ? gamma->DATA_PTR() : NULL, + beta != NULL ? beta->DATA_PTR() : NULL);) +} + +template +void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, + at::Tensor* input, int n1, int n2, const V* gamma, + const V* beta, double epsilon, T* grad_input, + V* grad_gamma, V* grad_beta) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL && beta != NULL) { + // compute grad_gamma(j) and grad_beta(j) + const int part_size = 16; + const dim3 threads2(32, 4, 1); + const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); + const int nshared2_a = + 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + at::Tensor part_grad_gamma = at::empty( + {part_size, n2}, input->options().dtype(at::ScalarType::Float)); + at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); + cuComputePartGradGammaBeta<<>>( + dout, input->DATA_PTR(), n1, n2, mean, invvar, U(epsilon), + part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR()); + + const dim3 threads3(32, 8, 1); + const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR(), part_size, + n1, n2, grad_gamma, grad_beta); + } + + // compute grad_input + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); + const dim3 threads1(32, 4, 1); + int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; + cuComputeGradInput<<>>( + dout, input->DATA_PTR(), n1, n2, mean, invvar, U(epsilon), gamma, + grad_input); +} + +void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, + at::Tensor* invvar, at::Tensor* input, int n1, + int n2, +#ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, +#else + at::IntList normalized_shape, +#endif + at::Tensor* gamma, at::Tensor* beta, + double epsilon, at::Tensor* grad_input, + at::Tensor* grad_gamma, at::Tensor* grad_beta) { + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), gamma->scalar_type(), + "cuda_layer_norm_gradient_kernel", + HostLayerNormGradient( + dout->DATA_PTR(), mean->DATA_PTR(), + invvar->DATA_PTR(), input, n1, n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->DATA_PTR() : NULL, + gamma != NULL ? beta->DATA_PTR() : NULL, epsilon, + grad_input->DATA_PTR(), + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + gamma != NULL ? grad_beta->DATA_PTR() : NULL);) +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..61c8a725052fdf9462d5adf1c9e43680e836caf3 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp @@ -0,0 +1,97 @@ +#include + +torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +torch::Tensor moe_dispatch_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, torch::Tensor dest_idx) { + CHECK_INPUT(batch_tokens); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx); +} + +torch::Tensor moe_dispatch_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(expert_grad); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx); +} + +torch::Tensor moe_combine_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(expert_tokens); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask, + dest_idx); +} + +std::vector moe_combine_backward(int s, int e, int c, int h, + torch::Tensor tokens_grad, + torch::Tensor expert_tokens, + torch::Tensor logits, + torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(tokens_grad); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens, + logits, mask, dest_idx); +} + +torch::Tensor moe_cumsum(torch::Tensor mask) { + CHECK_INPUT(mask); + return cumsum_sub_one_in_dim0(mask); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0"); + m.def("dispatch_forward", &moe_dispatch_forward, + "Forward operation in MoE dispatch function"); + m.def("dispatch_backward", &moe_dispatch_backward, + "Backward operation in MoE dispatch function"); + m.def("combine_forward", &moe_combine_forward, + "Combine operation in MoE combine function"); + m.def("combine_backward", &moe_combine_backward, + "Combine operation in MoE combine function"); +} diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..0454377a2fadb4ac4f1ded1359c71f17110b1ea3 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu @@ -0,0 +1,659 @@ +#include +#include +#include + +#include + +#include "block_reduce.h" + +template +__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, pack); + BlockStore(ts_store).Store(src_row + idx, pack); + } +} + +template +__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row1 + idx, pack); + BlockStore(ts_store).Store(dst_row2 + idx, pack); + } +} + +template +__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row1 + idx, pack1); + BlockLoad(ts_load).Load(dst_row2 + idx, pack2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] += pack2[i]; + } + + BlockStore(ts_store).Store(src_row + idx, pack1); + } +} + +template +__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack[i] *= weight; + } + + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, + T *weight_grad, const T weight, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens[pack_size]; + float thread_sum = 0; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row + idx, tokens); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum += grad[i] * tokens[i]; + grad[i] *= weight; + } + + BlockStore(ts_store).Store(src_row + idx, grad); + } + + blockReduce(&thread_sum); + + if (threadIdx.x == 0) *weight_grad = static_cast(thread_sum); +} + +template +__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row, + const T weight1, const T weight2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row1 + idx, pack1); + BlockLoad(ts_load).Load(src_row2 + idx, pack2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] = pack1[i] * weight1 + pack2[i] * weight2; + } + + BlockStore(ts_store).Store(dst_row + idx, pack1); + } +} + +template +__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, + T *tks_row1, T *tks_row2, T *weight_grad1, + T *weight_grad2, const T weight1, + const T weight2, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size], + sgrad2[pack_size]; + float thread_sum[2] = {0, 0}; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row1 + idx, tokens1); + BlockLoad(ts_load).Load(tks_row2 + idx, tokens2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum[0] += grad[i] * tokens1[i]; + thread_sum[1] += grad[i] * tokens2[i]; + sgrad1[i] = weight1 * grad[i]; + sgrad2[i] = weight2 * grad[i]; + } + + BlockStore(ts_store).Store(src_row1 + idx, sgrad1); + BlockStore(ts_store).Store(src_row2 + idx, sgrad2); + } + + blockReduce(thread_sum); + + if (threadIdx.x == 0) + *weight_grad1 = static_cast(thread_sum[0]); + else if (threadIdx.x == 1) + *weight_grad2 = static_cast(thread_sum[1]); +} + +// DISPATCH KERNELS -------------------------------- + +template +__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2, + const int cols, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_fwd(src_row, dst_row1, dst_row2, + cols); + else if (indicator1 != 0) + moe_dpch_one_fwd(src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_fwd(src_row, dst_row2, cols); + else + return; +} + +template +__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2, + const int cols, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_bwd(src_row, dst_row1, dst_row2, + cols); + else if (indicator1 != 0) + moe_dpch_one_bwd(src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_bwd(src_row, dst_row2, cols); + else + return; +} + +template +__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input, + int *mask1, int *mask2, int *dest1, + int *dest2, const int h) { + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + moe_dpch_fwd_selector( + batch_tokens + (row * h), expert_input + (dest1[row] * h), + expert_input + (dest2[row] * h), h, mask1[row], indicator2); +} + +template +__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1, + int *mask2, int *dest1, int *dest2, + const int h) { + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + moe_dpch_bwd_selector( + tokens_grad + (row * h), expert_grad + (dest1[row] * h), + expert_grad + (dest2[row] * h), h, mask1[row], indicator2); +} + +// COMBINE KERNELS -------------------------------- + +template +__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row, + const int cols, const T weight1, + const T weight2, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_fwd(src_row1, src_row2, dst_row, + weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_fwd(src_row1, dst_row, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_fwd(src_row2, dst_row, weight2, cols); + else + return; +} + +template +__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row, + const int cols, T *tks_row1, T *tks_row2, + T *wt_grad1, T *wt_grad2, const T weight1, + const T weight2, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_bwd(src_row1, src_row2, dst_row, + tks_row1, tks_row2, wt_grad1, + wt_grad2, weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_bwd(src_row1, dst_row, tks_row1, + wt_grad1, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_bwd(src_row2, dst_row, tks_row2, + wt_grad2, weight2, cols); + else + return; +} + +template +__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens, + T *logits, int *mask1, int *mask2, int *dest1, + int *dest2, const int e, const int c, + const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + T *row_log = logits + (row * e); + moe_cb_fwd_selector( + expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h), + combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row], + indicator2); +} + +template +__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, + T *logits, T *logits_grad, int *mask1, + int *mask2, int *dest1, int *dest2, + const int e, const int c, const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e); + moe_cb_bwd_selector( + expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h), + tokens_grad + (row * h), h, tks + (dest1[row] * h), + tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1], + row_log[eid2], mask1[row], indicator2); +} + +// CUMSUM KERNEL -------------------------------- + +template +__global__ void cumsum_kernel(int *inputs, int *outputs, const int s, + const int e) { + assert(s % pack_size == 0); + constexpr int bpack_size = block_size * pack_size; + int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1; + __shared__ int temp[block_size + 1]; + int pack[pack_size]; + + for (int idx = 0; idx < s; idx += bpack_size) { + int offset = 1; + + if (idx + tps < s) { + temp[tid] = inputs[tps * e + bid]; +#pragma unroll + for (int i = 1; i < pack_size; ++i) { + pack[i] = inputs[(tps + i) * e + bid]; + } +#pragma unroll + for (int i = 1; i < pack_size; ++i) { + temp[tid] += pack[i]; + } + } + + for (int i = block_size >> 1; i > 0; i >>= 1) { + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1; + temp[j + offset] += temp[j]; + } + offset <<= 1; + } + + if (tid == 0) { + temp[block_size] = temp[block_size - 1]; + temp[block_size - 1] = 0; + } + + for (int i = 1; i < block_size; i <<= 1) { + offset >>= 1; + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j]; + temp[j] = temp[k]; + temp[k] += ts; + } + } + __syncthreads(); + + if (tid == 0) temp[0] = temp[block_size]; + __syncthreads(); + + if (idx + tps < s) { + temp[tid + 1] += last_sum; +#pragma unroll + for (int i = pack_size - 1; i > 0; --i) { + outputs[(tps + i) * e + bid] = temp[tid + 1]; + temp[tid + 1] -= pack[i]; + } + outputs[tps * e + bid] = temp[tid + 1]; + } + __syncthreads(); + + last_sum += temp[0]; + inputs += bpack_size * e; + outputs += bpack_size * e; + } +} + +// LAUNCH FUNCTIONS -------------------------------- + +template +void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, + int *mask2, int *dest1, int *dest2, const int s, + const int h) { + if (h < 256) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); +} + +template +void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2, + int *dest1, int *dest2, const int s, const int h) { + if (h < 256) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); +} + +template +void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits, + int *mask1, int *mask2, int *dest1, int *dest2, + const int s, const int e, const int c, const int h) { + if (h < 256) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 512) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 1024) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 2048) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, + dest2, e, c, h); +} + +template +void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, + T *logits_grad, int *mask1, int *mask2, int *dest1, + int *dest2, const int s, const int e, const int c, + const int h) { + if (h < 256) + moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, + logits, logits_grad, mask1, mask2, + dest1, dest2, e, c, h); + else // if (h < 512) + moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, + logits, logits_grad, mask1, mask2, + dest1, dest2, e, c, h); + // else if (h < 1024) + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, + // dest1, dest2, e, c, h); + // else + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, + // dest1, dest2, e, c, h); +} + +void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { + if (s <= 256) + cumsum_kernel<256, 1><<>>(inputs, outputs, s, e); + else if (s <= 512) + cumsum_kernel<512, 1><<>>(inputs, outputs, s, e); + else if (s <= 1024) + cumsum_kernel<1024, 1><<>>(inputs, outputs, s, e); + else if (s <= 2048) + cumsum_kernel<1024, 2><<>>(inputs, outputs, s, e); + else + cumsum_kernel<1024, 4><<>>(inputs, outputs, s, e); +} + +// API FUNCTIONS -------------------------------- + +#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented yet for specific data type."); \ + } + +torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + auto res = torch::zeros( + {ec, h}, + torch::dtype(batch_tokens.dtype()).device(batch_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + batch_tokens.scalar_type(), "moe dispatch forward", + moe_dpch_fwd_launch( + batch_tokens.data(), res.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); + + return res; +} + +torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + auto res = torch::zeros( + {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_grad.scalar_type(), "moe dispatch backward", + moe_dpch_bwd_launch( + res.data(), expert_grad.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); + + return res; +} + +torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + assert(expert_tokens.dtype() == logits.dtype()); + + auto res = torch::zeros( + {s, h}, + torch::dtype(expert_tokens.dtype()).device(expert_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_tokens.scalar_type(), "moe combine forward", + moe_cb_fwd_launch( + expert_tokens.data(), res.data(), + logits.data(), mask[0].data(), + k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + h)); + + return res; +} + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + assert(tokens_grad.dtype() == expert_tokens.dtype()); + assert(expert_tokens.dtype() == logits.dtype()); + + auto egrad = torch::zeros( + {e * c, h}, + torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())), + wgrad = torch::zeros( + {s, e}, torch::dtype(logits.dtype()).device(logits.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + tokens_grad.scalar_type(), "moe combine backward", + moe_cb_bwd_launch( + tokens_grad.data(), egrad.data(), + expert_tokens.data(), logits.data(), + wgrad.data(), mask[0].data(), + k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + h)); + + return {egrad, wgrad}; +} + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { + assert(mask.dim() == 2); + assert(mask.dtype() == torch::kInt32); + + const int s = mask.size(0), e = mask.size(1); + auto res = + torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device())); + cumsum_launch(mask.data(), res.data(), s, e); + + return res; +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu new file mode 100644 index 0000000000000000000000000000000000000000..9cc3ae1eac1085d8ca59a97058459bf9a4f4be1a --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu @@ -0,0 +1,146 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu +/* Copyright 2020 The Microsoft DeepSpeed Team + Copyright NVIDIA/apex + This file is adapted from fused adam in NVIDIA/apex, commit a109f85 + Licensed under the MIT License. +*/ +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "multi_tensor_apply.cuh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +typedef enum { + ADAM_MODE_0 = 0, // L2 regularization mode + ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW) +} adamMode_t; + +using MATH_T = float; + +template +struct AdamFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, + const float beta1, const float beta2, const float beta1_correction, + const float beta2_correction, const float epsilon, const float lr, + adamMode_t mode, const float decay, const float div_scale) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T_g *g = (T_g *)tl.addresses[0][tensor_loc]; + g += chunk_idx * chunk_size; + + T_p *p = (T_p *)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + T_p *m = (T_p *)tl.addresses[2][tensor_loc]; + m += chunk_idx * chunk_size; + + T_p *v = (T_p *)tl.addresses[3][tensor_loc]; + v += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = g[i]; + r_p[ii] = p[i]; + r_m[ii] = m[i]; + r_v[ii] = v[i]; + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (div_scale > 0) r_g[ii] /= div_scale; + + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (lr * update); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p[i] = r_p[ii]; + m[i] = r_m[ii]; + v[i] = r_v[ii]; + } + } + } + } +}; + +void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, + const float beta2, const float epsilon, + const int step, const int mode, + const int bias_correction, const float weight_decay, + const float div_scale) { + using namespace at; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + DISPATCH_FLOAT_AND_HALF_FOR_G_P( + tensor_lists[0][0].scalar_type(), tensor_lists[1][0].scalar_type(), 0, + "adam", + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctor(), beta1, + beta2, bias_correction1, bias_correction2, epsilon, + lr, (adamMode_t)mode, weight_decay, div_scale);) + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh b/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ec55dd320b405c35f9c51078c46a633253f66df2 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh @@ -0,0 +1,130 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh +/* Copyright 2020 The Microsoft DeepSpeed Team + Copyright NVIDIA/apex + This file is adapted from fused adam in NVIDIA/apex, commit a109f85 + Licensed under the MIT License. +*/ +#include +#include +#include +#include +#include +#include + +#include "compat.h" + +// #include + +// This header is the one-stop shop for all your multi-tensor apply needs. + +// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) +constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; +constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; + +template +struct TensorListMetadata { + void *addresses[n][depth_to_max_tensors[n - 1]]; + int sizes[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a + // full int. + int start_tensor_this_launch; +}; + +template +__global__ void multi_tensor_apply_kernel(int chunk_size, + volatile int *noop_flag, T tl, + U callable, ArgTypes... args) { + // Hand the chunk information to the user-supplied functor to process however + // it likes. + callable(chunk_size, noop_flag, tl, args...); +} + +template +void multi_tensor_apply( + int block_size, int chunk_size, const at::Tensor &noop_flag, + const std::vector> &tensor_lists, T callable, + ArgTypes... args) { + TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + int len0 = tensor_lists[0].size(); + TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); + auto ref_device = tensor_lists[0][0].device(); + TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); + for (int l = 0; l < tensor_lists.size(); + l++) // No range-based for because I need indices + { + TORCH_CHECK(tensor_lists[l].size() == len0, + "Size mismatch among tensor lists"); + for (int t = 0; t < tensor_lists[l].size(); t++) { + // TODO: Print which tensor fails. + bool contiguous_memory = tensor_lists[l][t].is_contiguous(); +#ifdef VERSION_GE_1_5 + contiguous_memory = + (contiguous_memory || + tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast)); +#endif + TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); + TORCH_CHECK(tensor_lists[l][t].device() == ref_device, + "A tensor was not on the same device as the first tensor"); + TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), + "Size mismatch"); + } + } + + int ntensors = tensor_lists[0].size(); + + TensorListMetadata tl; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); + auto stream = at::cuda::getCurrentCUDAStream(); + + tl.start_tensor_this_launch = 0; + int loc_block_info = 0; + int loc_tensor_info = 0; + for (int t = 0; t < ntensors; t++) { + tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) + tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + loc_tensor_info++; + + int chunks_this_tensor = + (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + + for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { + // std::cout << chunks_this_tensor << std::endl; + tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tl.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks_this_tensor - 1); + bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); + bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); + if (tensors_full || blocks_full || last_chunk) { + // using accscalar_t = acc_type; + multi_tensor_apply_kernel<<>>( + chunk_size, noop_flag.DATA_PTR(), tl, callable, args...); + + AT_CUDA_CHECK(cudaGetLastError()); + + // Reset. The control flow possibilities here make my brain hurt. + loc_block_info = 0; + if (chunk == chunks_this_tensor - 1) { + // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 + // << std::endl; + loc_tensor_info = 0; + tl.start_tensor_this_launch = t + 1; + } else { + // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 + // << std::endl; + tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + loc_tensor_info = 1; + tl.start_tensor_this_launch = t; + } + } + } + } +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..49ab83e8fc81df4c9887d55ddc5503f20498bb7d --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu @@ -0,0 +1,382 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_l2norm_kernel.cu +#include +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "multi_tensor_apply.cuh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +template +__device__ __forceinline__ bool is_aligned(T *p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, + int src_offset) { + typedef + typename std::aligned_storage::type LT; + ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; +} + +template +struct L2NormFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl, + float *output, float *output_per_tensor, bool per_tensor, + int max_chunks_per_tensor) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + x_t *x = (x_t *)tl.addresses[0][tensor_loc]; + x += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + __shared__ float s_vals[512]; + + float vals[ILP]; // = {0}; // this probably works too but I want to be + // sure... + x_t r_x[ILP]; + for (int i = 0; i < ILP; i++) { + vals[i] = 0.f; + r_x[i] = 0; + } + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) { + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_x, x, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float next = static_cast(r_x[ii]); + vals[ii] += next * next; + } + } + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + float next = static_cast(x[i]); + vals[ii] += next * next; + } + } + } + } + + float val = 0.f; + for (int i = 0; i < ILP; i++) val += vals[i]; + + float final = reduce_block_into_lanes(s_vals, val); + + if (threadIdx.x == 0) { + if (!isfinite(final)) + *noop_gmem = + 1; // Blindly fire off a write. These will race but that's ok. + output[blockIdx.x] += final; + if (per_tensor) + output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * + max_chunks_per_tensor + + chunk_idx] = final; + } + } +}; + +// Probably better to template, but since we are not likely to support other +// norm +template +struct MaxNormFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl, + float *output, float *output_per_tensor, bool per_tensor, + int max_chunks_per_tensor) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + x_t *x = (x_t *)tl.addresses[0][tensor_loc]; + x += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + __shared__ float s_vals[512]; + + float vals[ILP]; // = {0}; // this probably works too but I want to be + // sure... + x_t r_x[ILP]; + for (int i = 0; i < ILP; i++) { + vals[i] = 0.f; + r_x[i] = 0; + } + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) { + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_x, x, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float next = static_cast(r_x[ii]); + vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next)); + } + } + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + float next = static_cast(x[i]); + vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next)); + } + } + } + } + + float val = 0.f; + for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i])); + + float final = reduce_block_into_lanes_max_op(s_vals, val); + + if (threadIdx.x == 0) { + if (!isfinite(final)) + *noop_gmem = + 1; // Blindly fire off a write. These will race but that's ok. + output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final)); + if (per_tensor) + output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * + max_chunks_per_tensor + + chunk_idx] = final; + } + } +}; + +__global__ void cleanup(float *output, float *output_per_tensor, float *ret, + float *ret_per_tensor, bool per_tensor, + int max_chunks_per_tensor) { + __shared__ float vals[512]; + + if (blockIdx.x == 0) { + float val = 0; + if (threadIdx.x < 320) val = output[threadIdx.x]; + + float final = reduce_block_into_lanes(vals, val); + + if (threadIdx.x == 0) *ret = sqrt(final); + } + + if (per_tensor) { + float *output_this_tensor = + output_per_tensor + blockIdx.x * max_chunks_per_tensor; + + float val = 0; + for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) + val += output_this_tensor[i]; + + float final = reduce_block_into_lanes(vals, val); + + if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final); + } +} + +__global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret, + float *ret_per_tensor, bool per_tensor, + int max_chunks_per_tensor, int norm_type, + float alpha, float beta) { + __shared__ float vals[512]; + + if (blockIdx.x == 0) { + float val = 0; + if (threadIdx.x < 320) val = output[threadIdx.x]; + + if (norm_type == 0) { + float final = reduce_block_into_lanes_max_op(vals, val); + if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final; + } else { + float final = reduce_block_into_lanes(vals, val); + if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final); + } + } + + if (per_tensor) { + float *output_this_tensor = + output_per_tensor + blockIdx.x * max_chunks_per_tensor; + + if (norm_type == 0) { + float val = 0; + for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) + val = fmaxf(fabsf(val), fabsf(output_this_tensor[i])); + + float final = reduce_block_into_lanes_max_op(vals, val); + + if (threadIdx.x == 0) + ret_per_tensor[blockIdx.x] = + alpha * ret_per_tensor[blockIdx.x] + beta * final; + } else { + float val = 0; + for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) + val += output_this_tensor[i]; + + float final = reduce_block_into_lanes(vals, val); + + if (threadIdx.x == 0) + ret_per_tensor[blockIdx.x] = sqrt(alpha * ret_per_tensor[blockIdx.x] * + ret_per_tensor[blockIdx.x] + + beta * final); + } + } +} + +std::tuple multi_tensor_l2norm_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python) { + bool per_tensor = + per_tensor_python.has_value() ? per_tensor_python.value() : false; + + auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); + auto output = at::zeros({320}, float_options); + + at::Tensor output_per_tensor; + at::Tensor ret_per_tensor; + + int ntensors = tensor_lists[0].size(); + int max_chunks_per_tensor = -1; + + if (per_tensor) { + for (int t = 0; t < ntensors; t++) { + int max_chunks_this_tensor = + (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + if (max_chunks_this_tensor > max_chunks_per_tensor) + max_chunks_per_tensor = max_chunks_this_tensor; + } + output_per_tensor = + at::zeros({ntensors * max_chunks_per_tensor}, float_options); + ret_per_tensor = at::empty({ntensors}, float_options); + } else { + ret_per_tensor = at::empty({0}, float_options); + } + + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", + multi_tensor_apply<1>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + L2NormFunctor(), output.DATA_PTR(), + per_tensor ? output_per_tensor.DATA_PTR() : nullptr, + per_tensor, max_chunks_per_tensor);) + + AT_CUDA_CHECK(cudaGetLastError()); + // AT_CUDA_CHECK(cudaDeviceSynchronize()); + + // This involves one more small kernel launches, but will be negligible end to + // end. I could get rid of these by hacking the functor + multi tensor harness + // with persistence logic, but keeping it simple for now + auto ret = at::empty({1}, output.options()); + const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); + auto stream = at::cuda::getCurrentCUDAStream(); + cleanup<<>>( + output.DATA_PTR(), + per_tensor ? output_per_tensor.DATA_PTR() : nullptr, + ret.DATA_PTR(), + per_tensor ? ret_per_tensor.DATA_PTR() : nullptr, per_tensor, + max_chunks_per_tensor); + + return std::tuple(ret, ret_per_tensor); +} + +// Compute and update grad norm +// Here use a per tensor norm, and blend new norm(n) and old norm(gn) by +// L-2: gn = sqrt(a * gn^2 + b * n^2) +// L-inf: gn = a * gn + b * n +void multi_tensor_norm_out_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor out, + const float alpha, const float beta, const int norm_type) { + auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); + TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), + "noop flag should be on the same device as tensors"); + // we don't need global thus uses empty here + auto output = at::empty({320}, float_options); + + at::Tensor output_per_tensor; + at::Tensor ret_per_tensor; + + int ntensors = tensor_lists[0].size(); + int max_chunks_per_tensor = -1; + + for (int t = 0; t < ntensors; t++) { + int max_chunks_this_tensor = + (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + if (max_chunks_this_tensor > max_chunks_per_tensor) + max_chunks_per_tensor = max_chunks_this_tensor; + } + + // Although it is single write then read, still need to be zero + // Since tailing element also participate cleanup + output_per_tensor = + at::zeros({ntensors * max_chunks_per_tensor}, float_options); + + if (norm_type == 0) { + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda", + multi_tensor_apply<1>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + MaxNormFunctor(), output.DATA_PTR(), + output_per_tensor.DATA_PTR(), true, max_chunks_per_tensor);) + } else { + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", + multi_tensor_apply<1>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + L2NormFunctor(), output.DATA_PTR(), + output_per_tensor.DATA_PTR(), true, max_chunks_per_tensor);) + } + AT_CUDA_CHECK(cudaGetLastError()); + + // AT_CUDA_CHECK(cudaDeviceSynchronize()); + + // This involves one more small kernel launches, but will be negligible end to + // end. I could get rid of these by hacking the functor + multi tensor harness + // with persistence logic, but keeping it simple for now + auto ret = at::empty({1}, output.options()); + + // Adding the following device guard since it happens sometimes that the + // tensors are on one device and the cuda stream is on another device which + // results in ILLEGAL MEM ACCESS error. + const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); + auto stream = at::cuda::getCurrentCUDAStream(); + cleanup_v2<<>>( + output.DATA_PTR(), output_per_tensor.DATA_PTR(), + ret.DATA_PTR(), out.DATA_PTR(), true, max_chunks_per_tensor, + norm_type, alpha, beta); + + return; +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu new file mode 100644 index 0000000000000000000000000000000000000000..54c4220190d80d6309e74c90da412d6ccda32c8f --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu @@ -0,0 +1,354 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_lamb.cu +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "multi_tensor_apply.cuh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +template +__device__ __forceinline__ bool is_aligned(T *p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, + int src_offset) { + typedef + typename std::aligned_storage::type LT; + ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; +} + +typedef enum { + MOMENT_MODE_0 = 0, // L2 regularization mode + MOMENT_MODE_1 = 1 // Decoupled weight decay mode +} adamMode_t; + +std::tuple multi_tensor_l2norm_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python); + +using MATH_T = float; + +template +struct LAMBStage1Functor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, + const float beta1, const float beta2, const float beta3, + const float beta1_correction, const float beta2_correction, + const float epsilon, adamMode_t mode, const float decay, + const float *global_grad_norm, const float max_global_grad_norm) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + float clipped_global_grad_norm = + (*global_grad_norm) > max_global_grad_norm + ? (*global_grad_norm) / max_global_grad_norm + : 1.0f; + + T *g = (T *)tl.addresses[0][tensor_loc]; + g += chunk_idx * chunk_size; + + T *p = (T *)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + T *m = (T *)tl.addresses[2][tensor_loc]; + m += chunk_idx * chunk_size; + + T *v = (T *)tl.addresses[3][tensor_loc]; + v += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(g) && + is_aligned(p) && is_aligned(m) && is_aligned(v)) { + T l_g[ILP]; + T l_p[ILP]; + T l_m[ILP]; + T l_v[ILP]; + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(l_g, g, 0, i_start); + if (decay != 0) load_store(l_p, p, 0, i_start); + load_store(l_m, m, 0, i_start); + load_store(l_v, v, 0, i_start); + // unpack +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_g[ii] = l_g[ii]; + if (decay == 0) { + r_p[ii] = MATH_T(0); + } else { + r_p[ii] = l_p[ii]; + } + r_m[ii] = l_m[ii]; + r_v[ii] = l_v[ii]; + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == MOMENT_MODE_0) { + MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; + // L2 on scaled grad + scaled_grad = scaled_grad + decay * r_p[ii]; + r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + r_p[ii] = next_m_unbiased / denom; + } else { + MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; + r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + l_p[ii] = r_p[ii]; + l_m[ii] = r_m[ii]; + l_v[ii] = r_v[ii]; + } + // store + load_store(g, l_p, i_start, 0); + load_store(m, l_m, i_start, 0); + load_store(v, l_v, i_start, 0); + } + } else { + // see note in multi_tensor_scale_kernel.cu + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = g[i]; + // special ?optimization? for lamb stage 1 + if (decay == 0) { + r_p[ii] = MATH_T(0); + } else { + r_p[ii] = p[i]; + } + r_m[ii] = m[i]; + r_v[ii] = v[i]; + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == MOMENT_MODE_0) { + MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; + // L2 on scaled grad + scaled_grad = scaled_grad + decay * r_p[ii]; + r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + r_p[ii] = next_m_unbiased / denom; + } else { + MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; + r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + g[i] = r_p[ii]; + m[i] = r_m[ii]; + v[i] = r_v[ii]; + } + } + } + } + } +}; + +// Step 2 reads in 'update' value and per-tensor param_norm and update_norm. +// It computes new parameter value. +template +struct LAMBStage2Functor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl, + const float *per_tensor_param_norm, const float *per_tensor_update_norm, + const float learning_rate, const float decay, bool use_nvlamb) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int tensor_num = tl.start_tensor_this_launch + tensor_loc; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + MATH_T ratio = learning_rate; + // nvlamb: apply adaptive learning rate to all parameters + // otherwise, only apply to those with non-zero weight decay + if (use_nvlamb || (decay != 0.0)) { + float param_norm = per_tensor_param_norm[tensor_num]; + float update_norm = per_tensor_update_norm[tensor_num]; + ratio = (update_norm != 0.0f && param_norm != 0.0f) + ? learning_rate * (param_norm / update_norm) + : learning_rate; + } + + T *update = (T *)tl.addresses[0][tensor_loc]; + update += chunk_idx * chunk_size; + + T *p = (T *)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && + is_aligned(update)) { + T r_p[ILP]; + T r_update[ILP]; + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_p, p, 0, i_start); + load_store(r_update, update, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_p[ii] = static_cast(r_p[ii]) - + (ratio * static_cast(r_update[ii])); + } + load_store(p, r_p, i_start, 0); + } + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { + MATH_T r_p[ILP]; + MATH_T r_update[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_p[ii] = p[i]; + r_update[ii] = update[i]; + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_p[ii] = r_p[ii] - (ratio * r_update[ii]); + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p[i] = r_p[ii]; + } + } + } + } + } +}; + +void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, + const float beta2, const float epsilon, + const int step, const int bias_correction, + const float weight_decay, const int grad_averaging, + const int mode, at::Tensor global_grad_norm, + const float max_grad_norm, + at::optional use_nvlamb_python) { + using namespace at; + // Master weight and 32bit momentum(potentially changing) is not handled by + // this So we assume every tensor are all in the same type + + bool use_nvlamb = + use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + // Handle grad averaging mode + float beta3 = 1.0f; + if (grad_averaging == 1) beta3 = 1 - beta1; + + std::vector> grad_list(tensor_lists.begin(), + tensor_lists.begin() + 1); + std::vector> param_list(tensor_lists.begin() + 1, + tensor_lists.begin() + 2); + + // Compute per tensor param norm + auto param_norm_tuple = + multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true); + + // We now in-place modify grad to store update before compute its norm + // Generally this is not a issue since people modify grad in step() method all + // the time We can also grab list of empty tensor to avoid this, but I'd like + // to save space/cpu code + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + LAMBStage1Functor(), beta1, beta2, + beta3, // 1-beta1 or 1 depends on averaging mode + bias_correction1, bias_correction2, epsilon, + (adamMode_t)mode, weight_decay, + global_grad_norm.DATA_PTR(), max_grad_norm);) + + // Compute update norms + auto update_norm_tuple = + multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true); + + std::vector> grad_param_list( + tensor_lists.begin(), tensor_lists.begin() + 2); + + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", + multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list, + LAMBStage2Functor(), + std::get<1>(param_norm_tuple).DATA_PTR(), + std::get<1>(update_norm_tuple).DATA_PTR(), + lr, weight_decay, use_nvlamb);) + + AT_CUDA_CHECK(cudaGetLastError()); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..360485dcd02fbfc21a76a2bfa6dd6568b8909499 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu @@ -0,0 +1,125 @@ +#include +#include +#include +#include +// Another possibility: +// #include + +#include +// Stringstream is a big hammer, but I want to rely on operator<< for dtype. +#include + +#include "multi_tensor_apply.cuh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +template +__device__ __forceinline__ bool is_aligned(T *p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, + int src_offset) { + typedef + typename std::aligned_storage::type LT; + ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; +} + +template +struct ScaleFunctor { + __device__ __forceinline__ void operator()(int chunk_size, + volatile int *noop_gmem, + TensorListMetadata<2> &tl, + float scale) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + in_t *in = (in_t *)tl.addresses[0][tensor_loc]; + in += chunk_idx * chunk_size; + + out_t *out = (out_t *)tl.addresses[1][tensor_loc]; + out += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + bool finite = true; + in_t r_in[ILP]; + out_t r_out[ILP]; + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && + is_aligned(out)) { + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_in, in, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_out[ii] = static_cast(r_in[ii]) * scale; + finite = finite && isfinite(r_in[ii]); + } + // store + load_store(out, r_out, i_start, 0); + } + } else { + // Non-divergent exit condition for __syncthreads, not necessary here + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_in[ii] = 0; + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) r_in[ii] = in[i]; + } + // note for clarification to future michael: + // From a pure memory dependency perspective, there's likely no point + // unrolling the write loop, since writes just fire off once their LDGs + // arrive. Put another way, the STGs are dependent on the LDGs, but not + // on each other. There is still compute ILP benefit from unrolling the + // loop though. +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_out[ii] = static_cast(r_in[ii]) * scale; + finite = finite && isfinite(r_in[ii]); + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) out[i] = r_out[ii]; + } + } + } + if (!finite) + *noop_gmem = + 1; // Blindly fire off a write. These will race but that's ok. + } +}; + +void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + float scale) { + using namespace at; + // The output (downscaled) type is always float. + // If build times suffer, think about where to put this dispatch, + // and what logic should be moved out of multi_tensor_apply. + + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda", + DISPATCH_FLOAT_AND_HALF( + tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda", + multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + ScaleFunctor(), + scale);)) + AT_CUDA_CHECK(cudaGetLastError()); + + // AT_CUDA_CHECK(cudaDeviceSynchronize()); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..35f2c9b4ed15eab94b1456ce436694180d706a45 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu @@ -0,0 +1,167 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu +#include +#include +#include +#include +#include +#include + +#include "compat.h" +#include "multi_tensor_apply.cuh" + +#define BLOCK_SIZE 512 +#define ILP 4 + +/** + * Perform fused SGD on multiple buffers + * N: number of tensors + * tl[0] : gradients + * tl[1] : weights + * tl[2] : momentum buffers + * tl[3] : fp16 weights (if appropriate) + * wd : weight_decay (scalar) + * momentum : momentum (scalar) + * dampening : momentum dampening (scalar) + * lr : learning rate (scalar) + * nesterov : enable nesterov (bool) + * first run : necessary for proper momentum handling & init + * wd_after_momentum : apply weight decay _after_ momentum instead of before + **/ +template +struct SGDFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl, + float wd, float momentum, float dampening, float lr, bool nesterov, + bool first_run, bool wd_after_momentum, float scale) { + // Early exit if we don't need to do anything + if (*noop_gmem) return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc]; + grad_in += chunk_idx * chunk_size; + + T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc]; + weight_in += chunk_idx * chunk_size; + + T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc]; + mom_in += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // Non-divergent exit condition for the __syncthreads + float incoming_grads[ILP]; + float incoming_weights[ILP]; + float incoming_moms[ILP]; + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + incoming_grads[ii] = 0; + incoming_weights[ii] = 0; + incoming_moms[ii] = 0; + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + incoming_grads[ii] = static_cast(grad_in[i]) * scale; + incoming_weights[ii] = static_cast(weight_in[i]); + incoming_moms[ii] = static_cast(mom_in[i]); + } + } + +// note for clarification to future michael: +// From a pure memory dependency perspective, there's likely no point unrolling +// the write loop, since writes just fire off once their LDGs arrive. +// Put another way, the STGs are dependent on the LDGs, but not on each other. +// There is still compute ILP benefit from unrolling the loop though. +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + // apply weight decay before momentum if necessary + if (wd != 0.f && !wd_after_momentum) + incoming_grads[ii] += wd * incoming_weights[ii]; + + if (momentum != 0.f) { + if (!first_run) + incoming_moms[ii] = incoming_moms[ii] * momentum + + (1.f - dampening) * incoming_grads[ii]; + else // initialize momentums to current incoming grads + incoming_moms[ii] = incoming_grads[ii]; + + if (nesterov) + incoming_grads[ii] += momentum * incoming_moms[ii]; + else + incoming_grads[ii] = incoming_moms[ii]; + } + + // Apply WD after momentum if desired + if (wd != 0.f && wd_after_momentum) + incoming_grads[ii] += wd * incoming_weights[ii]; + + // adjust the weight and write out + weight_in[i] += (-lr * incoming_grads[ii]); + + // also write out the new momentum + if (momentum != 0.f) mom_in[i] = incoming_moms[ii]; + } + } + } + } +}; + +void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + float wd, float momentum, float dampening, float lr, + bool nesterov, bool first_run, + bool wd_after_momentum, float scale) { + auto num_tensors = tensor_lists.size(); + auto grad_type = tensor_lists[0][0].scalar_type(); + auto weight_type = tensor_lists[1][0].scalar_type(); + + TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), + "expected noop flag to be on the same device as tensors"); + + // We have 3 possibilities to handle here, in terms of + // grad_type, param_type, momentum_type + // 1. fp16, fp16, fp16 + // 2. fp32, fp32, fp32 + // 3. fp16, fp32, fp32 + // It's easier to hardcode these possibilities than to use + // switches etc. to handle the cross-product of cases where + // we don't want the majority of them. + + // Case 1. fp16, fp16, fp16, No + if (grad_type == at::ScalarType::Half && + weight_type == at::ScalarType::Half && num_tensors == 3) { + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + SGDFunctor(), wd, momentum, + dampening, lr, nesterov, first_run, wd_after_momentum, + scale); + } + // Case 2. fp32, fp32, fp32 + else if (grad_type == at::ScalarType::Float && + weight_type == at::ScalarType::Float && num_tensors == 3) { + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + SGDFunctor(), wd, momentum, dampening, + lr, nesterov, first_run, wd_after_momentum, scale); + } + // Case 3. fp16, fp32, fp32 + else if (grad_type == at::ScalarType::Half && + weight_type == at::ScalarType::Float && num_tensors == 3) { + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + SGDFunctor(), wd, momentum, + dampening, lr, nesterov, first_run, wd_after_momentum, + scale); + } else { + AT_ERROR( + "multi_tensor_sgd only supports some combinations of gradient & weight " + "types. Given: ", + "gradient: ", grad_type, ", weight: ", weight_type, + ", num_lists: ", num_tensors); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d08f3dbc74d8292163c1a81c421cd89ec252de63 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp @@ -0,0 +1,406 @@ +#include "multihead_attention_1d.h" + +#include +#include +#include + +#if TORCH_VERSION_MAJOR > 1 || \ + (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13) +#include +#else +#include +#endif +#include + +#include "context.h" +#include "kernels.h" + +template +MultiHeadAttention::MultiHeadAttention(int layer_id, int max_batch_tokens, + int max_seq_len, int hidden_size, + int num_heads, + float attn_prob_dropout_ratio, + float hidden_output_dropout_ratio, + bool pre_or_postLayerNorm) + : _layer_id(layer_id), + _max_batch_tokens(max_batch_tokens), + _max_seq_len(max_seq_len), + _hidden_size(hidden_size), + _heads(num_heads), + _training(true), + _pre_or_postLayerNorm(pre_or_postLayerNorm), + _qkv_linear( + typename FeedForward::Config(3 * hidden_size, hidden_size)), + _attn_out_linear( + typename FeedForward::Config(hidden_size, hidden_size)), + _attn_ln(typename Normalize_Layer::Config(hidden_size, false), + _max_batch_tokens), + _softmax(typename Softmax::Config(num_heads)), + _attn_prob_dropout(typename Dropout::Config(attn_prob_dropout_ratio), + _max_batch_tokens * _heads * _max_seq_len), + _attn_dropout(typename Dropout::Config(hidden_output_dropout_ratio), + _max_batch_tokens * _hidden_size), + _attn_scores(typename StridedBatchGemm::Config( + (T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T, + CUBLAS_OP_N)), + _attn_context(typename StridedBatchGemm::Config( + T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) { + assert(_hidden_size % _heads == 0); +} + +template +MultiHeadAttention::~MultiHeadAttention() { + free_mem_buffer(); +} + +template +void MultiHeadAttention::attn_layer_fw(const T *input_ptr, + const T *input_mask_ptr, + T *output_ptr, T *buffer) { + T *q_tf_ptr = _qkv_ptr; + T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; + T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; + + if (_pre_or_postLayerNorm) { + _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr, + _batch_tokens, _stream); + } + const T *gemmQKV_inp_ptr = + _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; + _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); + _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer, + _cublasHandle); + + launch_bias_add_transform_20314(q_tf_ptr, buffer, _attn_qkvb_ptr, + _batch_size, _seq_len, 3, _heads / pg_size, + _hidden_size / _heads, _stream); + + // attention scores, q*k + _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr, + _cublasHandle); + + // Softmax + Mask + _softmax.reset_size(_heads / pg_size); + _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len, + _seq_len, _stream, true); + + // attn prob dropout. + _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr, + _batch_heads * _seq_len * _seq_len, _stream); + + // attention context, score * v + _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr, + _cublasHandle); + + // [b, nh, s, ad] -> [b, s, nh, ad] + launch_transform4d_0213(_attn_o_inp_ptr, buffer, _batch_size, _seq_len, + _hidden_size / pg_size, _heads / pg_size, 1, + _stream); + + _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); + _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr, + output_ptr, _cublasHandle); + + // allreduce + if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { + } else { + auto data_type = torch::kFloat; + if (typeid(T) != typeid(float)) { + data_type = torch::kHalf; + } + auto output_tensor = torch::from_blob( + output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)}, + torch::TensorOptions(torch::kCUDA).dtype(data_type)); + std::vector allreduce_tensors = {output_tensor}; + auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); + work->wait(); + } + + _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr, + _attn_ob_ptr, _batch_tokens, _hidden_size, + _stream); + if (!_pre_or_postLayerNorm) { + // in-place ln since ln-input will not be used in post-ln mode + _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, + _batch_tokens, _stream); + } +} + +template +void MultiHeadAttention::Forward(const T *input_ptr, const T *input_mask_ptr, + T *out_ptr) { + _stream = Context::Instance().get_stream(); + _cublasHandle = Context::Instance().get_cublashandle(); + T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim + + attn_layer_fw(input_ptr, input_mask_ptr, out_ptr, attn_buffer); +} + +template +void MultiHeadAttention::attn_layer_bw(const T *input_ptr, + const T *input_mask_ptr, + const T *output_ptr, + const T *grad_output_ptr, + T *grad_input_ptr, T *buffer) { + cudaStream_t streams[2] = {_stream, _stream}; + + const T *q_tf_ptr = _qkv_ptr; + const T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; + const T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; + // batch_dim = batch_size * seq_len * hidden_size + // buffer size: batch_dim * 3 + max(batch_dim * 3, + // batch_size * head_num * seq_len * seq_len) + T *grad_residual_ptr = buffer; + buffer += _batch_dim; + + T *grad_input_buf_ptr = buffer; // batch_dim + T *grad_qkv_5d_ptr = buffer; // batch_dim * 3 + buffer += 3 * _batch_dim / pg_size; + + T *grad_qkv_4d_ptr = buffer; // batch_dim * 3 + T *grad_softmax_ptr = buffer; // batch_size * head_num * seq_len * seq_len + // buffer += max(3 * _batch_dim, + // batch_size * head_num * seq_len * seq_len); + + if (_pre_or_postLayerNorm) { + _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, + grad_output_ptr, _batch_tokens, + _hidden_size, _stream); + } else { + _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr, + grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr, + _attn_nb_ptr, _batch_tokens, streams); + _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, + grad_residual_ptr, _batch_tokens, + _hidden_size, _stream); + } + + // bw of output project + _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); + _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr, + _attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr, + _cublasHandle, _stream, grad_input_buf_ptr, nullptr, + false); + launch_transform_0213(grad_input_ptr, grad_input_buf_ptr, _batch_size, + _seq_len, _hidden_size / pg_size, _heads / pg_size, + _stream); + + // bw of score * v + _attn_context.Backward( + _batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle, + grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr); + + _attn_prob_dropout.d_dropout(grad_softmax_ptr, + _batch_heads * _seq_len * _seq_len, _stream); + + _softmax.reset_size(_heads / pg_size); + _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len, + _seq_len, _stream); + + // bw of q * k + _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr, + _cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size, + grad_qkv_5d_ptr); + + // [3, b, nh, s, ad] -> [b, s, 3, h] + launch_transform4d_0213(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size, + _seq_len, _hidden_size / pg_size, _heads / pg_size, + 3, _stream); + + const T *gemmQKV_inp_ptr = + _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; + _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); + _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr, + _attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr, + _cublasHandle, _stream, grad_input_buf_ptr, nullptr, + true); + + // allreduce + if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { + } else { + auto data_type = torch::kFloat; + if (typeid(T) != typeid(float)) { + data_type = torch::kHalf; + } + auto grad_input_tensor = + torch::from_blob(grad_input_buf_ptr, + {int(_batch_size), int(_seq_len), int(_hidden_size)}, + torch::TensorOptions(torch::kCUDA).dtype(data_type)); + std::vector allreduce_tensors = {grad_input_tensor}; + auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); + work->wait(); + } + + if (_pre_or_postLayerNorm) { + _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr, + grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr, + _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams); + } else { + // FIXME later + launch_fused_add2(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr, + _batch_size, _seq_len, _hidden_size, _stream); + } +} + +template +void MultiHeadAttention::Backward(const T *grad_output_ptr, + const T *input_ptr, const T *output_ptr, + const T *input_mask_ptr, + T *grad_input_ptr) { + _stream = Context::Instance().get_stream(); + _cublasHandle = Context::Instance().get_cublashandle(); + T *buffer = _shared_mem_ptr; + + /* + buffer size needed by attn bw: + 4 * _batch_dim + max(3 * _batch_dim, + _batch_size * _head_num * _seq_len * _seq_len); + */ + attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr, + grad_input_ptr, buffer); +} + +template +void MultiHeadAttention::SetTrainingMode(bool training) { + // Dropout will be skipped when not in training model. + _attn_prob_dropout.SetTrainingMode(training); + _attn_dropout.SetTrainingMode(training); +} + +template +T *MultiHeadAttention::_shared_mem_ptr = nullptr; + +template class MultiHeadAttention; +template class MultiHeadAttention<__half>; + +// x is torch::Tensor +#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +static std::unordered_map> s_multihead_attention; + +template +int create_multihead_attention(int layer_id, int max_batch_tokens, + int max_seq_len, int hidden_dim, int num_heads, + float attn_prob_dropout_ratio, + float hidden_dropout_ratio, + bool pre_or_postLayerNorm, + c10::intrusive_ptr pg_) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + Context::Instance().set_stream(stream); + auto layer = std::make_shared>( + layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads, + attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm); + + layer->SetPG(pg_); + + s_multihead_attention[layer_id] = layer; + + std::string dtype = (std::is_same::value) ? "half" : "float"; + + return 0; +} + +template +std::vector multihead_attention_fw( + int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask, + const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias, + const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias, + const torch::Tensor &norm_weight, const torch::Tensor &norm_bias, + bool training_mode, bool prelayernorm) { + CHECK_INPUT(input); + CHECK_INPUT(input_mask); + + const T *input_ptr = (const T *)input.data_ptr(); + const T *input_mask_ptr = (const T *)input_mask.data_ptr(); + + auto output = torch::empty_like(input); + T *out_ptr = (T *)output.data_ptr(); + + std::shared_ptr> layer = + std::static_pointer_cast>( + s_multihead_attention[layer_id]); + layer->set_cur_batch_shape(input.size(0), input.size(1)); + layer->SetTrainingMode(training_mode); + + layer->_attn_qkvw_ptr = (const T *)in_proj_weight.data_ptr(); + layer->_attn_qkvb_ptr = (const T *)in_proj_bias.data_ptr(); + layer->_attn_ow_ptr = (const T *)out_proj_weight.data_ptr(); + layer->_attn_ob_ptr = (const T *)out_proj_bias.data_ptr(); + layer->_attn_nw_ptr = (const T *)norm_weight.data_ptr(); + layer->_attn_nb_ptr = (const T *)norm_bias.data_ptr(); + + layer->Forward(input_ptr, input_mask_ptr, out_ptr); + + return {output}; +} + +template +std::vector multihead_attention_bw( + int layer_id, const torch::Tensor &grad_dec_output, + const torch::Tensor &output, const torch::Tensor &input, + const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight, + const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight, + const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight, + const torch::Tensor &norm_bias) { + auto g_output = grad_dec_output.contiguous(); + CHECK_INPUT(g_output); + CHECK_INPUT(output); + CHECK_INPUT(input); + CHECK_INPUT(input_mask); + + auto grad_input = torch::empty_like(input); + auto grad_in_proj_weight = torch::empty_like(in_proj_weight); + auto grad_in_proj_bias = torch::empty_like(in_proj_bias); + auto grad_out_proj_weight = torch::empty_like(out_proj_weight); + auto grad_out_proj_bias = torch::empty_like(out_proj_bias); + auto grad_norm_weight = torch::empty_like(norm_weight); + auto grad_norm_bias = torch::empty_like(norm_bias); + + // inputs. + const T *grad_dec_output_ptr = (const T *)g_output.data_ptr(); + const T *input_ptr = (const T *)input.data_ptr(); + const T *output_ptr = (const T *)output.data_ptr(); + const T *input_mask_ptr = (const T *)input_mask.data_ptr(); + + // outputs. + T *grad_input_ptr = (T *)grad_input.data_ptr(); + + std::shared_ptr> layer = + std::static_pointer_cast>( + s_multihead_attention[layer_id]); + layer->set_cur_batch_shape(g_output.size(0), g_output.size(1)); + + layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr(); + layer->_grad_attn_qkvb_ptr = (T *)grad_in_proj_bias.data_ptr(); + layer->_grad_attn_ow_ptr = (T *)grad_out_proj_weight.data_ptr(); + layer->_grad_attn_ob_ptr = (T *)grad_out_proj_bias.data_ptr(); + layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr(); + layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr(); + + layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, + grad_input_ptr); + + return {grad_input, grad_in_proj_weight, grad_in_proj_bias, + grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight, + grad_norm_bias}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("multihead_attention_fw_fp32", &multihead_attention_fw, + "Multi-head Attention forward with fp32 (CUDA)"); + m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>, + "Multi-head Attention forward with fp16 (CUDA)"); + m.def("multihead_attention_bw_fp32", &multihead_attention_bw, + "Multi-head Attention backward with fp32 (CUDA)"); + m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>, + "Multi-head Attention backward with fp16 (CUDA)"); + m.def("create_multihead_attention_fp32", &create_multihead_attention, + "Create Multi-head Attention with fp32 (CUDA)"); + m.def("create_multihead_attention_fp16", &create_multihead_attention<__half>, + "Create Multi-head Attention with fp16 (CUDA)"); +} diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h new file mode 100644 index 0000000000000000000000000000000000000000..6505eb31fb9fa31b5e04c7c1473c9bd9baae49b6 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h @@ -0,0 +1,167 @@ +#pragma once + +#include +#include +#include +#include +#include + +#if TORCH_VERSION_MAJOR > 1 || \ + (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13) +#include +#else +#include +#endif + +#include +#include + +#include "cuda_util.h" +#include "dropout.h" +#include "feed_forward.h" +#include "normalize_layer.h" +#include "softmax.h" +#include "strided_batch_gemm.h" + +template +class MultiHeadAttention { + public: + MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, + int hidden_size, int num_heads, float attn_dropout_ratio, + float hidden_output_dropout_ratio, + bool pre_or_postLayerNorm); + + virtual ~MultiHeadAttention(); + + void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr); + + void Backward(const T *grad_output_ptr, const T *input_ptr, + const T *output_ptr, const T *input_mask_ptr, + T *grad_input_ptr); + + void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, + T *buffer); + + void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, + const T *output_ptr, const T *grad_output_ptr, + T *grad_input_attn_layer_bwptr, T *buffer); + + void set_cur_batch_shape(int batch_size, int seq_len) { + _batch_size = batch_size; + _seq_len = seq_len; + _batch_tokens = batch_size * seq_len; + _batch_heads = batch_size * _heads / pg_size; + _batch_dim = _batch_tokens * _hidden_size; + _attn_scores.SetConfig(_seq_len, _seq_len, _hidden_size / _heads); + _attn_context.SetConfig(_hidden_size / _heads, _seq_len, _seq_len); + } + + void SetTrainingMode(bool training); + inline bool IsTrainingMode() const { return _training; } + + void SetPG(c10::intrusive_ptr pg_) { + pg = pg_; + pg_size = 1; + if (pg != c10::detail::UniqueVoidPtr()) { + pg_size = pg->getSize(); + } + allocate_mem_buffer(); + } + + // weights ptr + const T *_attn_qkvw_ptr; + const T *_attn_qkvb_ptr; + const T *_attn_ow_ptr; + const T *_attn_ob_ptr; + const T *_attn_nw_ptr; + const T *_attn_nb_ptr; + + // grads ptr + T *_grad_attn_qkvw_ptr; + T *_grad_attn_qkvb_ptr; + T *_grad_attn_ow_ptr; + T *_grad_attn_ob_ptr; + T *_grad_attn_nw_ptr; + T *_grad_attn_nb_ptr; + + private: + void allocate_mem_buffer() { + // allocate local gpu memory + if (_pre_or_postLayerNorm) { + _gemmQKV_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); + } else { + _gemmQKV_inp_ptr = nullptr; + } + + _qkv_ptr = cuda_malloc(_max_batch_tokens * _hidden_size * 3); + _soft_out_ptr = + cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); + _ctx_bufB_ptr = + cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); + _attn_o_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); + + // buffer size needed by attn bw + size_t smem_size = + 4 * _max_batch_tokens * _hidden_size / pg_size + + std::max(3 * _max_batch_tokens * _hidden_size / pg_size, + _max_batch_tokens * _heads / pg_size * _max_seq_len); + + if (!_shared_mem_ptr) { + cuda_free(_shared_mem_ptr); + _shared_mem_ptr = cuda_malloc(smem_size); + } + } + + void free_mem_buffer() { + // free local gpu memory + cuda_free(_gemmQKV_inp_ptr); + cuda_free(_qkv_ptr); + cuda_free(_soft_out_ptr); + cuda_free(_ctx_bufB_ptr); + cuda_free(_attn_o_inp_ptr); + + // free shared gpu memory between layers + cuda_free(_shared_mem_ptr); + _shared_mem_ptr = nullptr; + } + + // const parameter between batch + const size_t _layer_id; + const size_t _hidden_size; + const size_t _heads; + const size_t _max_batch_tokens; + const size_t _max_seq_len; + const bool _pre_or_postLayerNorm; + // dynamic parameter between batch + size_t _batch_size; + size_t _seq_len; + size_t _batch_tokens; + size_t _batch_heads; + size_t _batch_dim; + bool _training; + + cublasHandle_t _cublasHandle; + cudaStream_t _stream; + + // layers + FeedForward _qkv_linear; + FeedForward _attn_out_linear; + Normalize_Layer _attn_ln; + Softmax _softmax; + Dropout _attn_prob_dropout; + Dropout _attn_dropout; + StridedBatchGemm _attn_scores; + StridedBatchGemm _attn_context; + + // local GPU memory + T *_gemmQKV_inp_ptr; + T *_qkv_ptr; + T *_soft_out_ptr; + T *_ctx_bufB_ptr; + T *_attn_o_inp_ptr; + // shared GPU memory between layer + static T *_shared_mem_ptr; + + c10::intrusive_ptr pg; + int pg_size; +}; diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4ae3c853ca5e844272ca4fdb907c8c95a7f2b787 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp @@ -0,0 +1,84 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +int get_batch_per_block_cuda( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads); + +torch::Tensor fwd( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) { + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); + + return fwd_cuda(input, mask, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +int get_batch_per_block( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); +} + +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + + m.def("backward", + &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); + + m.def("get_batch_per_block", + &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, + "Return Batch per block size." + ); +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..1583030b8235acfb3a3af1a86fa938901ae52bbb --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h @@ -0,0 +1,492 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Explicit masking + */ +template +__global__ void scaled_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const acc_t scale, + int micro_batch_size, + int element_count, + int pad_batches) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i*element_count+it*WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + +template +__global__ void scaled_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); + } + } + } +} +} // end of anonymous namespace + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; +} + +template +void dispatch_scaled_masked_softmax_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + int pad_batches) +{ + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); + dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) +{ + TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count/batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + default: + break; + } + } +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..41781ebc7fe0dfffd4a502fb61ff60b362781222 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu @@ -0,0 +1,89 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#include +#include +#include +#include +#include +#include +#include + +#include "scaled_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); +} + +torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor) { + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, + // seq_len] + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 2048); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); + TORCH_INTERNAL_ASSERT(mask.size(1) == 1); + TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); + TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = torch::empty( + {batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* mask_ptr = static_cast(mask.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), "dispatch_scaled_masked_softmax_forward", + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), scale_factor, + query_seq_len, key_seq_len, batches, attn_heads, pad_batches);); + return softmax_results; +} + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, + // seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + // Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, query_seq_len, key_seq_len, batches, attn_heads);); + + // backward pass is completely in-place + return output_grads; +} +} // namespace scaled_masked_softmax +} // namespace fused_softmax +} // namespace multihead_attn diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cbbc3706497a69ea142e2898acc321c383fd1939 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp @@ -0,0 +1,54 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#include +#include + +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor); + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { + AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return fwd_cuda(input, scale_factor); +} + +torch::Tensor bwd(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, float scale_factor) { + AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +} // end namespace scaled_upper_triang_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..3af487f9de0ffdc22faaca142cbc2ff86b68d03e --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h @@ -0,0 +1,500 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +template +__device__ __inline__ void copy_zero_vector(Datatype *dst); + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Implicit time (diagonal masking) + */ +template +__global__ void scaled_upper_triang_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it+element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } + } + copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } + } + } +} + +template +__global__ void scaled_upper_triang_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); + } + } + } +} + +} // end of anonymous namespace + +template +void dispatch_scaled_upper_triang_masked_softmax_forward( + output_t *dst, + const input_t *src, + const input_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_upper_triang_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..62c56e6f78707b7c4e3be734f45f1b0bafb2d223 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu @@ -0,0 +1,75 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#include +#include +#include +#include +#include +#include +#include + +#include "scaled_upper_triang_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { + // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = input.size(0); + const int seq_len = input.size(1); + TORCH_INTERNAL_ASSERT(seq_len <= 2048); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({attn_batches, seq_len, seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_forward", + dispatch_scaled_upper_triang_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), scale_factor, seq_len, + seq_len, attn_batches);); + return softmax_results; +} + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + // output grads is a 3d tensor with dimensions [attn_batches, seq_len, + // seq_len] + const int attn_batches = output_grads.size(0); + const int seq_len = output_grads.size(1); + TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + // Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_backward", + dispatch_scaled_upper_triang_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, seq_len, seq_len, attn_batches);); + + // backward pass is completely in-place + return output_grads; +} +} // namespace scaled_upper_triang_masked_softmax +} // namespace fused_softmax +} // namespace multihead_attn diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/colossalai/kernel/cuda_native/csrc/type_shim.h new file mode 100644 index 0000000000000000000000000000000000000000..2f180a7783ec98b5d9b8286ccac4887b2a5b4bc5 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/type_shim.h @@ -0,0 +1,264 @@ +/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */ +/* Copyright 2020 The Microsoft DeepSpeed Team + Copyright NVIDIA/apex + This file is adapted from fused adam in NVIDIA/apex, commit a109f85 + Licensed under the MIT License. +*/ +#include + +#include "compat.h" + +#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch (TYPEIN) { \ + case at::ScalarType::Float: { \ + using scalar_t_in = float; \ + switch (TYPEOUT) { \ + case at::ScalarType::Float: { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } + +// Forward/backward compatiblity hack around +// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288 +// pending more future-proof guidance from upstream. +// struct TypeShim +// { +// const at::Type& payload; +// TypeShim(const at::Type& type) : payload(type) {} +// // Enable trivial conversion to a const at::Type& for pre-3aeb78 +// operator const at::Type&(){ return payload; }; +// // Enable dispatch switch statements to take *this directly for post-3aeb78 +// //operator at::ScalarType(){ return payload.; }; +// }; + +#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Byte: { \ + using scalar_t_##LEVEL = uint8_t; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Double: { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Double: { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \ + if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) { \ + using g_scalar_t_##LEVEL = float; \ + using p_scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::Float && \ + PTYPE == at::ScalarType::Half) { \ + using g_scalar_t_##LEVEL = float; \ + using p_scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::Half && \ + PTYPE == at::ScalarType::Float) { \ + using g_scalar_t_##LEVEL = at::Half; \ + using p_scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) { \ + using g_scalar_t_##LEVEL = at::Half; \ + using p_scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + } else { \ + AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \ + "'"); \ + } + +template +__device__ __forceinline__ T reduce_block_into_lanes( + T *x, T val, int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = + blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) { + if (tid < i) x[tid] = x[tid] + x[tid + i]; + __syncthreads(); + } + + T final; + + if (tid < 32) { + if (blockSize >= 64) + final = x[tid] + x[tid + 32]; + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = final + __shfl_down_sync(0xffffffff, final, i); + } + + if (share_result) { + if (tid < lanes) x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + __syncthreads(); + } + + return final; +} + +template +__device__ __forceinline__ T reduce_block_into_lanes_max_op( + T *x, T val, int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = + blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) { + if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i])); + __syncthreads(); + } + + T final; + + if (tid < 32) { + if (blockSize >= 64) + final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = + fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); + } + + if (share_result) { + if (tid < lanes) x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + __syncthreads(); + } + + return final; +} diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..d793815ed681c52a7d17b49528d86aee2d441a9f --- /dev/null +++ b/colossalai/kernel/cuda_native/flash_attention.py @@ -0,0 +1,635 @@ +""" +A general attention module using the flash attention kernels from xformers: +https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha +""" + +import math +import os +import subprocess + +import torch + +try: + from xformers.ops.fmha import memory_efficient_attention + HAS_MEM_EFF_ATTN = True +except ImportError: + HAS_MEM_EFF_ATTN = False + print('please install xformers from https://github.com/facebookresearch/xformers') + +if HAS_MEM_EFF_ATTN: + + from typing import Optional + + from einops import rearrange + from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp + from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias + + from .scaled_softmax import AttnMaskType + + allow_alibi = True + for op in MemoryEfficientAttentionCutlassOp: + allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) + + class Unpad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): + ctx.save_for_backward(indices) + # [b, s, ...] + assert tensor.ndim >= 3 + ctx.bsz = tensor.shape[0] + out = rearrange(tensor, 'b s ... -> (b s) ...') + ctx.shape = out.shape + # [1, ntokens, ...] + return out[indices].unsqueeze(0) + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + # [b*s, ...] + grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) + grad[indices] = grad_output.squeeze(0) + grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz) + # [b, s, ...] + return grad, None + + class Repad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): + ctx.save_for_backward(indices) + # [ntokens, ...] + tensor = tensor.squeeze(0) + out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) + # [b*s, ...] + out[indices] = tensor + # [b, s, ...] + out = rearrange(out, '(b s) ... -> b s ...', b=batch_size) + return out + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + # [b*s, ...] + grad_output = rearrange(grad_output, 'b s ... -> (b s) ...') + grad = grad_output[indices] + # [1, ntokens, ...] + return grad.unsqueeze(0), None, None, None + + class ColoAttention(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0): + super().__init__() + assert embed_dim % num_heads == 0, \ + f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." + self.scale = 1 / math.sqrt(embed_dim // num_heads) + self.dropout = dropout + + @staticmethod + def get_seq_info_from_mask(attn_mask: torch.Tensor): + indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten() + seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten().tolist() + return indices, seqlens + + @staticmethod + def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return Unpad.apply(tensor, indices) + + @staticmethod + def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: + return Repad.apply(tensor, indices, batch_size, seq_len) + + def forward(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + attn_mask_type: Optional[AttnMaskType] = None, + bias: Optional[torch.Tensor] = None): + batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] + attn_bias = None + if attn_mask_type == AttnMaskType.padding: # bert style + assert attn_mask is not None, \ + f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." + assert attn_mask.dim() == 2, \ + "attention mask is supposed to have shape (batch_size, seq_len), " + \ + f"but got {attn_mask.dim()} dimensions." + if tgt_len == src_len: + q_indices, q_seqlen = self.get_seq_info_from_mask(attn_mask) + kv_seqlen = None + if batch_size > 1: + query, key, value = self.unpad(torch.stack([query, key, value], dim=2), q_indices).unbind(dim=2) + else: + q_indices = torch.arange(batch_size * tgt_len, dtype=torch.int32, device=query.device) + q_seqlen = torch.LongTensor([tgt_len] * batch_size, device=query.device) + kv_indices, kv_seqlen = self.get_seq_info_from_mask(attn_mask) + if batch_size > 1: + query = rearrange(query, "b s ... -> c (b s) ...", c=1) + key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2) + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + elif attn_mask_type == AttnMaskType.causal: # gpt style + attn_bias = LowerTriangularMask() + + if bias is not None: # alibi / relative position emebedding + assert allow_alibi, "flash attention with bias is not supported in this system." + assert attn_mask_type == AttnMaskType.causal, \ + "attention with bias is only supported for causal attention so far." + attn_bias = attn_bias.add_bias(bias) + + out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale) + + if attn_mask_type == AttnMaskType.padding and batch_size > 1: + out = self.repad(out, q_indices, batch_size, tgt_len) + + out = rearrange(out, 'b s h d -> b s (h d)') + return out + + +########################################################################## +# the flash attention functions below that are copied +# from the OpenAI/triton repository will be deprecated +# You can find the repository in Triton https://github.com/openai/triton +# You can find the source file in https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py +# Reference: +# 1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf +# 2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf + + +def triton_cuda_check(): + cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda") + cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip() + cuda_version = cuda_version.split('release ')[1] + cuda_version = cuda_version.split(',')[0] + cuda_version = cuda_version.split('.') + if len(cuda_version) == 2 and \ + (int(cuda_version[0]) == 11 and int(cuda_version[1]) >= 4) or \ + int(cuda_version[0]) > 11: + return True + return False + + +try: + import triton + import triton.language as tl + if triton_cuda_check(): + HAS_TRITON = True + else: + print("triton requires cuda >= 11.4") + HAS_TRITON = False +except ImportError: + print('please install triton from https://github.com/openai/triton') + HAS_TRITON = False +try: + from flash_attn.flash_attention import FlashAttention + from flash_attn.flash_attn_interface import ( + flash_attn_unpadded_func, + flash_attn_unpadded_kvpacked_func, + flash_attn_unpadded_qkvpacked_func, + ) + HAS_FLASH_ATTN = True +except ImportError: + HAS_FLASH_ATTN = False + print('please install flash_attn from https://github.com/HazyResearch/flash-attention') + +if HAS_TRITON: + # the following functions are adapted from the OpenAI Triton tutorial + # https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py + @triton.jit + def _fwd_kernel( + Q, + K, + V, + sm_scale, + TMP, + L, + M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + t_ptrs = TMP + off_hz * N_CTX + offs_m + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + # loop over k, v and update accumulator + for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + start_n * stride_kn) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + qk *= sm_scale + qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")) + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + start_n * stride_vk) + p = p.to(tl.float16) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + l_ptrs = L + off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(l_ptrs, l_i) + tl.store(m_ptrs, m_i) + # initialize pointers to output + offs_n = tl.arange(0, BLOCK_DMODEL) + off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + + @triton.jit + def _bwd_preprocess( + Out, + DO, + L, + NewDO, + Delta, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, + ): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + denom = tl.load(L + off_m).to(tl.float32) + # compute + do = do / denom[:, None] + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + @triton.jit + def _bwd_kernel( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + DK, + DV, + L, + M, + D, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + Z, + H, + N_CTX, + num_block, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_qz + off_h * stride_qh + V += off_z * stride_qz + off_h * stride_qh + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, k, trans_b=True) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + m = tl.load(m_ptrs + offs_m_curr) + p = tl.exp(qk * sm_scale - m[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(p.to(tl.float16), do, trans_a=True) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, v, trans_b=True) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(ds.to(tl.float16), q, trans_a=True) + # # compute dq + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds.to(tl.float16), k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + # # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + + class _TritonFlashAttention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, sm_scale): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) + tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + tmp, + L, + m, + o, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + BLOCK_DMODEL=Lk, + num_warps=num_warps, + num_stages=1, + ) + ctx.save_for_backward(q, k, v, o, L, m) + ctx.BLOCK = BLOCK + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, l, m = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + do_scaled = torch.empty_like(do) + delta = torch.empty_like(l) + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)]( + o, + do, + l, + do_scaled, + delta, + BLOCK_M=ctx.BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + + # NOTE: kernel currently buggy for other values of `num_warps` + num_warps = 8 + _bwd_kernel[(ctx.grid[1],)]( + q, + k, + v, + ctx.sm_scale, + o, + do_scaled, + dq, + dk, + dv, + l, + m, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + ctx.grid[0], + BLOCK_M=ctx.BLOCK, + BLOCK_N=ctx.BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + num_warps=num_warps, + num_stages=1, + ) + return dq, dk, dv, None + + def triton_flash_attention(q, k, v, sm_scale): + """ + Arguments: + q: (batch, nheads, seq, headdim) + k: (batch, nheads, seq, headdim) + v: (batch, nheads, seq, headdim) + sm_scale: float. The scaling of QK^T before applying softmax. + Return: + out: (batch, nheads, seq, headdim) + """ + if HAS_TRITON: + return _TritonFlashAttention.apply(q, k, v, sm_scale) + else: + raise RuntimeError("Triton kernel requires CUDA 11.4+!") + + +if HAS_FLASH_ATTN: + + def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False): + """ + Arguments: + qkv: (batch * seqlen, 3, nheads, headdim) + batch_size: int. + seq_len: int. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + dropout_p: float. + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + out: (total, nheads, headdim). + """ + max_s = seq_len + cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device) + out = flash_attn_unpadded_qkvpacked_func(qkv, + cu_seqlens, + max_s, + dropout_p, + softmax_scale=sm_scale, + causal=causal) + return out + + def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False): + """ + Arguments: + q: (batch * q_seqlen, nheads, headdim) + kv: (batch * kv_seqlen, 2, nheads, headdim) + batch_size: int. + seq_len: int. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + dropout_p: float. + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + out: (total, nheads, headdim). + """ + cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device) + cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, + step=kv_seqlen, + dtype=torch.int32, + device=kv.device) + out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p, + sm_scale, causal) + return out + + def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False): + """ + Arguments: + q: (batch * q_seqlen, nheads, headdim) + k: (batch * kv_seqlen, nheads, headdim) + v: (batch * kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + out: (total, nheads, headdim). + """ + cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device) + cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen, + step=kv_seqlen, + dtype=torch.int32, + device=k.device) + return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale, + causal) + + +########################################################################## diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..40355a41ed0d2c1b1d1b6266b38bf160e28f8a28 --- /dev/null +++ b/colossalai/kernel/cuda_native/layer_norm.py @@ -0,0 +1,78 @@ +"""This code is from NVIDIA apex: + https://github.com/NVIDIA/apex + with some changes. """ + +import numbers + +import torch +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn import init +from torch.nn.parameter import Parameter + +from colossalai.kernel.op_builder.layernorm import LayerNormBuilder + +try: + from colossalai._C import layer_norm +except ImportError: + layer_norm = None + + +class FusedLayerNormAffineFunction(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, input, weight, bias, normalized_shape, eps): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + + global layer_norm + if layer_norm is None: + + layer_norm = LayerNormBuilder().load() + output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps) + ctx.layernorm_op = layer_norm + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias \ + = layer_norm.backward_affine( + grad_output.contiguous(), mean, invvar, + input_, ctx.normalized_shape, + weight_, bias_, ctx.eps) + + return grad_input, grad_weight, grad_bias, None, None + + +class MixedFusedLayerNorm(torch.nn.Module): + + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None): + super(MixedFusedLayerNorm, self).__init__() + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.weight = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype)) + self.bias = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype)) + self.reset_parameters() + + def reset_parameters(self): + + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, input): + + return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps) + + def __repr__(self): + return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})' diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3b6470cdcbb98e622fc68062b609dc590f5301ae --- /dev/null +++ b/colossalai/kernel/cuda_native/multihead_attention.py @@ -0,0 +1,257 @@ +import math +from dataclasses import dataclass + +import torch +from torch import nn +from torch.autograd import Function + + +def check_config(config): + if config.hidden_size % config.nhead != 0: + raise Exception("hidden_size % nhead != 0") + + factor = 8 if config.fp16 else 4 + upbound = factor * 1024 * 4 + if config.hidden_size > upbound: + # as required by ln backward kernel currently + raise Exception(f"hidden_size > {upbound}") + + head_dim = config.hidden_size // config.nhead + if head_dim % factor != 0: + # as required by reshape kernel + raise Exception(f"head_dim({head_dim}) % {factor} != 0") + + +def calc_offset(sizes): + offsets = [0] + tmp = 0 + for x in sizes: + tmp += x + offsets.append(tmp) + return offsets + + +colossal_multihead_attention = None + + +@dataclass +class Config: + max_batch_tokens: int # max batch token numbers + max_seq_len: int # max sequence length + hidden_size: int # size of transformer hidden layers + nhead: int # number of heads in attention + attn_prob_dropout_ratio: float # attention score dropout ratio + hidden_dropout_ratio: float # dropout ration before residual + norm_first: bool # norm_first + fp16: bool # fp16 presion + + +class MultiHeadAttention1DFunc(Function): + + @staticmethod + def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, + norm_bias, config): + cuda_module = colossal_multihead_attention + forward_func = (cuda_module.multihead_attention_fw_fp16 + if config.fp16 else cuda_module.multihead_attention_fw_fp32) + if config.fp16: + input = input.to(torch.half) + input_mask = input_mask.to(torch.half) + + (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, + out_proj_bias, norm_weight, norm_bias, config.training, config.norm_first) + + if config.is_grad_enabled and config.training: + ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, + out_proj_bias, norm_weight, norm_bias) + ctx.config = config + return output + + @staticmethod + def backward(ctx, grad_output): + assert ctx.config.training + + cuda_module = colossal_multihead_attention + backward_func = (cuda_module.multihead_attention_bw_fp16 + if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32) + + output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, \ + out_proj_bias, norm_weight, norm_bias = ctx.saved_tensors + + grad_input = None + grad_in_proj_weight = None + grad_in_proj_bias = None + grad_out_proj_weight = None + grad_out_proj_bias = None + grad_norm_weight = None + grad_norm_bias = None + + if ctx.config.fp16: + grad_output = grad_output.to(torch.half) + output = output.to(torch.half) + input = input.to(torch.half) + input_mask = input_mask.to(torch.half) + grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, \ + grad_out_proj_bias, grad_norm_weight, grad_norm_bias = backward_func( + ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, + in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias) + + return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, + grad_norm_weight, grad_norm_bias, None) + + +class MultiHeadAttention(nn.Module): + """Initialize the MultiHeadAttention. + + Static variable: + + layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated, + e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23. + + Arguments: + hidden_size: Total dimension of hidden_size. + nhead: Number of parallel attention heads. + batch_size: Batch Size for one forward + max_seq_len: Max length of input sequence + dropout: Dropout probability + norm_first: perform LayerNorms before attention + """ + + layer_id = 0 + + def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None): + super(MultiHeadAttention, self).__init__() + + self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, + fp16) + check_config(self.config) + self.pg = pg + self.pg_size = 1 + if self.pg: + self.pg_size = pg.size() + self.config.layer_id = MultiHeadAttention.layer_id + MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1 + + # Load cuda modules if needed + global colossal_multihead_attention + if colossal_multihead_attention is None: + from colossalai.kernel.op_builder import MultiHeadAttnBuilder + multihead_attention = MultiHeadAttnBuilder().load() + colossal_multihead_attention = multihead_attention + + # create the layer in cuda kernels. + cuda_module = colossal_multihead_attention + create_layer_func = (cuda_module.create_multihead_attention_fp16 + if self.config.fp16 else cuda_module.create_multihead_attention_fp32) + + create_layer_func( + self.config.layer_id, + self.config.max_batch_tokens, + self.config.max_seq_len, + self.config.hidden_size, + self.config.nhead, + self.config.attn_prob_dropout_ratio, + self.config.hidden_dropout_ratio, + self.config.norm_first, + self.pg, + ) + + hs = self.config.hidden_size + + self.precision = torch.float32 + if self.config.fp16: + self.precision = torch.half + + self.hs_per_rank = int(hs / self.pg_size) + + self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs)) + self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank)) + self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank)) + self.out_proj_bias = nn.Parameter(torch.Tensor(hs)) + self.norm_weight = nn.Parameter(torch.Tensor(hs)) + self.norm_bias = nn.Parameter(torch.Tensor(hs)) + + self.reset_parameters() + torch.cuda.empty_cache() + + def calc_bound(self, w): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w) + bound = 1.0 / math.sqrt(fan_in) + return bound + + def reset_parameters(self): + hs = self.config.hidden_size + + nn.init.zeros_(self.out_proj_bias) + + nn.init.ones_(self.norm_weight) + nn.init.zeros_(self.norm_bias) + + if self.pg_size > 1: + rank_in_pg = torch.distributed.get_rank(self.pg) + attn_qkvw_global = torch.empty(hs * 3, hs) + attn_qkvb_global = torch.empty(hs * 3) + nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0)) + bound = self.calc_bound(attn_qkvw_global) + nn.init.uniform_(attn_qkvb_global, -bound, bound) + + attn_qkvw_global = attn_qkvw_global.cuda() + attn_qkvb_global = attn_qkvb_global.cuda() + torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg) + torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg) + attn_qkvw_global = attn_qkvw_global.cpu() + attn_qkvb_global = attn_qkvb_global.cpu() + + with torch.no_grad(): + self.in_proj_weight.copy_( + attn_qkvw_global.view(3, hs, hs)[:, + int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / + self.pg_size), :]) + self.in_proj_bias.copy_( + attn_qkvb_global.view(3, hs)[:, + int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / + self.pg_size)]) + + attn_ow_global = torch.empty(hs, hs) + nn.init.xavier_uniform_(attn_ow_global, 1.0) + attn_ow_global = attn_ow_global.cuda() + torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg) + attn_ow_global = attn_ow_global.cpu() + with torch.no_grad(): + self.out_proj_weight.copy_(attn_ow_global[:, + int(hs * rank_in_pg / + self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)]) + + else: + attn_qkvw = self.in_proj_weight.view(-1, hs) + nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0)) + bound = self.calc_bound(attn_qkvw) + nn.init.uniform_(self.in_proj_bias, -bound, bound) + + nn.init.xavier_uniform_(self.out_proj_weight, 1.0) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars) + return destination + + def forward(self, hidden_states, encoder_padding_mask): + self.config.training = self.training + self.config.is_grad_enabled = torch.is_grad_enabled() + hidden_states = hidden_states.contiguous() + encoder_padding_mask = ((encoder_padding_mask * -1e8).type_as(hidden_states).contiguous()) + + bs, sl, dim = hidden_states.size() + if bs * sl > self.config.max_batch_tokens: + raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.") + if sl > self.config.max_seq_len: + raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.") + if len(encoder_padding_mask.size()) == 1: + assert bs == 1 and sl == encoder_padding_mask.size(0) + else: + assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1) + + output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, self.in_proj_weight, + self.in_proj_bias, self.out_proj_weight, self.out_proj_bias, + self.norm_weight, self.norm_bias, self.config) + + return output.to(self.precision) diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..24e458bb3ea53d87fb5cca21339155dbaa35be07 --- /dev/null +++ b/colossalai/kernel/cuda_native/scaled_softmax.py @@ -0,0 +1,188 @@ +# This code from NVIDIA Megatron: +# with minor changes. + +import enum + +import torch +import torch.nn as nn + +from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder +from colossalai.kernel.op_builder.scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder + +try: + from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax +except ImportError: + scaled_masked_softmax = None + scaled_upper_triang_masked_softmax = None + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + + +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + global scaled_upper_triang_masked_softmax + if scaled_upper_triang_masked_softmax: + scaled_upper_triang_masked_softmax = ScaledUpperTrainglemaskedSoftmaxBuilder().load() + + scale_t = torch.tensor([scale]) + softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) + + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + + return input_grads, None + + +class ScaledMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + + 1. Scale the tensor. + 2. Apply the mask. + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, mask, scale): + scale_t = torch.tensor([scale]) + + # build and load kernel if not pre-built + global scaled_masked_softmax + if scaled_masked_softmax is None: + scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load() + + softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + return input_grads, None, None, None + + +class FusedScaleMaskSoftmax(nn.Module): + """ + Fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: Flag to indicate if input in fp16 data format. + input_in_bf16: Flag to indicate if input in bf16 data format. + attn_mask_type: Attention mask type (pad or causal) + scaled_masked_softmax_fusion: Flag to indicate user want to use softmax fusion + mask_func: Mask function to be applied. + softmax_in_fp32: If True, softmax in performed at fp32 precision. + scale: Scaling factor used in input tensor scaling. + """ + + def __init__( + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, + ): + super(FusedScaleMaskSoftmax, self).__init__() + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + assert not (self.input_in_fp16 + and self.input_in_bf16), "both fp16 and bf16 flags cannot be active at the same time." + self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.scale = scale + assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled" + + def forward(self, input, mask): + # [b, np, sq, sk] + assert input.dim() == 4 + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if (self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and mask is not None # mask tensor must not be None + and 16 < sk <= 2048 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 2048: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type == AttnMaskType.causal: + if attn_batches % batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + b, np, sq, sk = input.size() + scale = self.scale if self.scale is not None else 1.0 + + if self.attn_mask_type == AttnMaskType.causal: + assert sq == sk, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, sq, sk) + input = input.view(-1, sq, sk) + probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) + return probs.view(b, np, sq, sk) + else: + # input is 4D tensor (b, np, sq, sk) + return ScaledMaskedSoftmax.apply(input, mask, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + + return probs + + def get_batch_per_block(self, sq, sk, b, np): + # build and load kernel if not pre-built + global scaled_masked_softmax + if scaled_masked_softmax is None: + scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load() + + return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np) diff --git a/colossalai/kernel/jit/__init__.py b/colossalai/kernel/jit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57b8fb7b2e996ea0f0336dad1e42ea379d608b15 --- /dev/null +++ b/colossalai/kernel/jit/__init__.py @@ -0,0 +1,8 @@ +from .option import set_jit_fusion_options +from .bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference +from .bias_gelu import bias_gelu_impl + +__all__ = [ + "bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl", + "set_jit_fusion_options" +] diff --git a/colossalai/kernel/jit/bias_dropout_add.py b/colossalai/kernel/jit/bias_dropout_add.py new file mode 100644 index 0000000000000000000000000000000000000000..3687dde79a08b7f8f192d6516694938828aae659 --- /dev/null +++ b/colossalai/kernel/jit/bias_dropout_add.py @@ -0,0 +1,24 @@ +import torch + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor + out = torch.nn.functional.dropout(x + bias, p=prob, training=training) + out = residual + out + return out + + +@torch.jit.script +def bias_dropout_add_fused_train(x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_fused_inference(x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, False) diff --git a/colossalai/kernel/jit/bias_gelu.py b/colossalai/kernel/jit/bias_gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..33b4ac32b044f662e475386a3b8e7504b54b108f --- /dev/null +++ b/colossalai/kernel/jit/bias_gelu.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +import torch + +###### BIAS GELU FUSION/ NO AUTOGRAD ################ +# 1/sqrt(2*pi)-> 0.3989423 +# 1/sqrt(2) -> 0.70710678 +# sqrt(2/pi) -> 0.79788456 +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) + + +@torch.jit.script +def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def bias_gelu_back(g, bias, y): + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff * g + + +class GeLUFunction(torch.autograd.Function): + + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(bias, input) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_gelu_back(grad_output, bias, input) + return tmp, tmp + + +bias_gelu_impl = GeLUFunction.apply diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py new file mode 100644 index 0000000000000000000000000000000000000000..aa41f57678fc116ac4acef72f52763f5dadabfed --- /dev/null +++ b/colossalai/kernel/jit/option.py @@ -0,0 +1,79 @@ +import torch + +from colossalai.nn.layer.colossalai_layer import Embedding, Linear +from colossalai.utils import get_current_device + +from .bias_dropout_add import bias_dropout_add_fused_train +from .bias_gelu import bias_gelu_impl + +JIT_OPTIONS_SET = False + + +def set_jit_fusion_options(): + """Set PyTorch JIT layer fusion options. + """ + # LSG: the latest pytorch and CUDA versions may not support + # the following jit settings + global JIT_OPTIONS_SET + if JIT_OPTIONS_SET == False: + # flags required to enable jit fusion kernels + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): + # nvfuser + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(True) + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(True) + torch._C._debug_set_autodiff_subgraph_inlining(False) + else: + # legacy pytorch fuser + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + + JIT_OPTIONS_SET = True + + +def warmup_jit_fusion(batch_size: int, + hidden_size: int, + seq_length: int = 512, + vocab_size: int = 32768, + dtype: torch.dtype = torch.float32): + """ Compilie JIT functions before the main training steps """ + + embed = Embedding(vocab_size, hidden_size).to(get_current_device()) + linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device()) + linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device()) + + x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device()) + x = embed(x) + y, y_bias = linear_1(x) + z, z_bias = linear_2(y) + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for bias_grad, input_grad in zip([True, True], [False, True]): + for _ in range(10): + bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device()) + input_ = torch.rand_like(y, dtype=dtype, device=get_current_device()) + bias.requires_grad, input_.requires_grad = bias_grad, input_grad + bias_gelu_impl(input_, bias) + + # Warmup fused bias+dropout+add + dropout_rate = 0.1 + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]): + for _ in range(10): + input_ = torch.rand_like(z, dtype=dtype, device=get_current_device()) + residual = torch.rand_like(x, dtype=dtype, device=get_current_device()) + bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device()) + input_.requires_grad = input_grad + bias.requires_grad = bias_grad + residual.requires_grad = residual_grad + bias_dropout_add_fused_train(input_, bias, residual, dropout_rate) + + torch.cuda.empty_cache() diff --git a/colossalai/kernel/op_builder b/colossalai/kernel/op_builder new file mode 120000 index 0000000000000000000000000000000000000000..db4f9c3350652dff87a26f6bb95c024160842a99 --- /dev/null +++ b/colossalai/kernel/op_builder @@ -0,0 +1 @@ +../../op_builder \ No newline at end of file diff --git a/colossalai/logging/__init__.py b/colossalai/logging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97fe4f89ded370c6a71ce277bb33e252912b6f17 --- /dev/null +++ b/colossalai/logging/__init__.py @@ -0,0 +1,38 @@ +import logging +from typing import List, Optional + +from .logger import DistributedLogger + +__all__ = ['get_dist_logger', 'DistributedLogger', 'disable_existing_loggers'] + + +def get_dist_logger(name: str = 'colossalai') -> DistributedLogger: + """Get logger instance based on name. The DistributedLogger will create singleton instances, + which means that only one logger instance is created per name. + + Args: + name (str): name of the logger, name must be unique + + Returns: + :class:`colossalai.logging.DistributedLogger`: A distributed logger singleton instance. + """ + return DistributedLogger.get_instance(name=name) + + +def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']) -> None: + """Set the level of existing loggers to `WARNING`. By default, it will "disable" all existing loggers except the logger named "colossalai". + + Args: + include (Optional[List[str]], optional): Loggers whose name in this list will be disabled. + If set to `None`, `exclude` argument will be used. Defaults to None. + exclude (List[str], optional): Loggers whose name not in this list will be disabled. + This argument will be used only when `include` is None. Defaults to ['colossalai']. + """ + if include is None: + filter_func = lambda name: name not in exclude + else: + filter_func = lambda name: name in include + + for log_name in logging.Logger.manager.loggerDict.keys(): + if filter_func(log_name): + logging.getLogger(log_name).setLevel(logging.WARNING) diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..af7b7de54a8d481312c4934ac1a24a2f1bcaf7e0 --- /dev/null +++ b/colossalai/logging/logger.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import inspect +import logging +from pathlib import Path +from typing import List, Union + +import colossalai +from colossalai.context.parallel_mode import ParallelMode + + +class DistributedLogger: + """This is a distributed event logger class essentially based on :class:`logging`. + + Args: + name (str): The name of the logger. + + Note: + The parallel_mode used in ``info``, ``warning``, ``debug`` and ``error`` + should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + __instances = dict() + + @staticmethod + def get_instance(name: str): + """Get the unique single logger instance based on name. + + Args: + name (str): The name of the logger. + + Returns: + DistributedLogger: A DistributedLogger object + """ + if name in DistributedLogger.__instances: + return DistributedLogger.__instances[name] + else: + logger = DistributedLogger(name=name) + return logger + + def __init__(self, name): + if name in DistributedLogger.__instances: + raise Exception( + 'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger') + else: + handler = None + formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s') + try: + from rich.logging import RichHandler + handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True) + handler.setFormatter(formatter) + except ImportError: + handler = logging.StreamHandler() + handler.setFormatter(formatter) + + self._name = name + self._logger = logging.getLogger(name) + self._logger.setLevel(logging.INFO) + if handler is not None: + self._logger.addHandler(handler) + self._logger.propagate = False + + DistributedLogger.__instances[name] = self + + @staticmethod + def __get_call_info(): + stack = inspect.stack() + + # stack[1] gives previous function ('info' in our case) + # stack[2] gives before previous function and so on + + fn = stack[2][1] + ln = stack[2][2] + func = stack[2][3] + + return fn, ln, func + + @staticmethod + def _check_valid_logging_level(level: str): + assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level' + + def set_level(self, level: str) -> None: + """Set the logging level + + Args: + level (str): Can only be INFO, DEBUG, WARNING and ERROR. + """ + self._check_valid_logging_level(level) + self._logger.setLevel(getattr(logging, level)) + + def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None) -> None: + """Save the logs to file + + Args: + path (A string or pathlib.Path object): The file to save the log. + mode (str): The mode to write log into the file. + level (str): Can only be INFO, DEBUG, WARNING and ERROR. + suffix (str): The suffix string of log's name. + """ + assert isinstance(path, (str, Path)), \ + f'expected argument path to be type str or Path, but got {type(path)}' + self._check_valid_logging_level(level) + + if isinstance(path, str): + path = Path(path) + + # create log directory + path.mkdir(parents=True, exist_ok=True) + + # set the default file name if path is a directory + if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL): + rank = 0 + else: + rank = colossalai.core.global_context.get_global_rank() + + if suffix is not None: + log_file_name = f'rank_{rank}_{suffix}.log' + else: + log_file_name = f'rank_{rank}.log' + path = path.joinpath(log_file_name) + + # add file handler + file_handler = logging.FileHandler(path, mode) + file_handler.setLevel(getattr(logging, level)) + formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s') + file_handler.setFormatter(formatter) + self._logger.addHandler(file_handler) + + def _log(self, + level, + message: str, + parallel_mode: ParallelMode = ParallelMode.GLOBAL, + ranks: List[int] = None) -> None: + if ranks is None: + getattr(self._logger, level)(message) + else: + local_rank = colossalai.core.global_context.get_local_rank(parallel_mode) + if local_rank in ranks: + getattr(self._logger, level)(message) + + def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: + """Log an info message. + + Args: + message (str): The message to be logged. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): + The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. + ranks (List[int]): List of parallel ranks. + """ + message_prefix = "{}:{} {}".format(*self.__get_call_info()) + self._log('info', message_prefix, parallel_mode, ranks) + self._log('info', message, parallel_mode, ranks) + + def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: + """Log a warning message. + + Args: + message (str): The message to be logged. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): + The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. + ranks (List[int]): List of parallel ranks. + """ + message_prefix = "{}:{} {}".format(*self.__get_call_info()) + self._log('warning', message_prefix, parallel_mode, ranks) + self._log('warning', message, parallel_mode, ranks) + + def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: + """Log a debug message. + + Args: + message (str): The message to be logged. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): + The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. + ranks (List[int]): List of parallel ranks. + """ + message_prefix = "{}:{} {}".format(*self.__get_call_info()) + self._log('debug', message_prefix, parallel_mode, ranks) + self._log('debug', message, parallel_mode, ranks) + + def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: + """Log an error message. + + Args: + message (str): The message to be logged. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): + The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. + ranks (List[int]): List of parallel ranks. + """ + message_prefix = "{}:{} {}".format(*self.__get_call_info()) + self._log('error', message_prefix, parallel_mode, ranks) + self._log('error', message, parallel_mode, ranks) diff --git a/colossalai/nn/__init__.py b/colossalai/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..910ad203180c8dd533ebc7732d26d94a20b72929 --- /dev/null +++ b/colossalai/nn/__init__.py @@ -0,0 +1,6 @@ +from ._ops import * +from .layer import * +from .loss import * +from .lr_scheduler import * +from .metric import * +from .optimizer import * diff --git a/colossalai/nn/_ops/__init__.py b/colossalai/nn/_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4991ad9a2217f904287d438ca37c8c4717a40a67 --- /dev/null +++ b/colossalai/nn/_ops/__init__.py @@ -0,0 +1,9 @@ +from .addmm import colo_addmm +from .batch_norm import colo_batch_norm +from .element_wise import * +from .embedding import colo_embedding +from .embedding_bag import colo_embedding_bag +from .layernorm import colo_layernorm +from .linear import colo_linear +from .loss import colo_cross_entropy +from .view import colo_view diff --git a/colossalai/nn/_ops/_utils.py b/colossalai/nn/_ops/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..24877bbb552f9dcd0e5b00baf783eef7f2ccecbb --- /dev/null +++ b/colossalai/nn/_ops/_utils.py @@ -0,0 +1,283 @@ +from typing import List, Optional, Union + +import torch +import torch.distributed as dist + +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn.layer.utils import divide +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup + +GeneralTensor = Union[ColoTensor, torch.Tensor] +Number = Union[int, float] + + +def convert_to_colo_tensor(tensor: Optional[GeneralTensor], pg: ProcessGroup) -> Optional[ColoTensor]: + if tensor is not None and not isinstance(tensor, ColoTensor): + tensor = ColoTensor.from_torch_tensor(tensor, ColoTensorSpec(pg)) + return tensor + + +def set_parallel_input(input_parallel: bool): + env.parallel_input_1d = input_parallel + + +def get_parallel_input(): + return env.parallel_input_1d + + +def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + +def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank) + + +def _reduce(input_, pg: ProcessGroup): + # skip if only one rank involved + if pg.tp_world_size() == 1: + return input_ + assert input_.device.type == 'cuda' + group = pg.tp_process_group() + dist.all_reduce(input_, group=group) + + return input_ + + +def _split(input_, pg: ProcessGroup, dim=-1): + # skip if only one rank involved + world_size = pg.tp_world_size() + if world_size == 1: + return input_ + + # Split along last dimension. + dim_size = input_.size(dim) + assert dim_size % world_size == 0, \ + f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ + f'cannot split tensor evenly' + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + rank = pg.tp_local_rank() + output = tensor_list[rank].contiguous() + + return output + + +def _gather(input_, pg: ProcessGroup, dim=-1): + # skip if only one rank involved + world_size = pg.tp_world_size() + if world_size == 1: + return input_ + + # all gather + rank = pg.tp_local_rank() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + assert input_.device.type == 'cuda' + group = pg.tp_process_group() + torch.distributed.all_gather(tensor_list, input_, group=group) + + # concat + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +class _ReduceGrad(torch.autograd.Function): + """ + Pass the input to the model parallel region. + + Args: + input_: input matrix. + process_group: parallel mode. + """ + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_, process_group): + ctx.mode = process_group + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output, ctx.mode), None + + +class _ReduceInput(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + process_group: parallel mode. + """ + + @staticmethod + def symbolic(graph, input_): + return _reduce(input_) + + @staticmethod + def forward(ctx, input_, process_group): + return _reduce(input_, process_group) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_: input matrix. + process_group: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _split(input_) + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.mode = process_group + ctx.dim = dim + return _split(input_, process_group, dim) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output, ctx.mode, ctx.dim), None, None + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + process_group: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _gather(input_) + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.mode = process_group + ctx.dim = dim + return _gather(input_, process_group, dim) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.mode, ctx.dim), None, None + + +def reduce_grad(input_, process_group): + return _ReduceGrad.apply(input_, process_group) + + +def reduce_input(input_, process_group): + return _ReduceInput.apply(input_, process_group) + + +def split_forward_gather_backward(input_, process_group, dim): + return _SplitForwardGatherBackward.apply(input_, process_group, dim) + + +def gather_forward_split_backward(input_, process_group, dim): + return _GatherForwardSplitBackward.apply(input_, process_group, dim) + + +def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: int) -> torch.Tensor: + world_size = pg.tp_world_size() + if world_size == 1: + return x + + # TODO: enabling mpi backend to support CPU all_to_all + assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend" + + shapes = list(x.size()) + shapes[scatter_dim] = shapes[scatter_dim] // world_size + + scatter_list = [each.contiguous() for each in torch.tensor_split(x, world_size, scatter_dim)] + gather_list = [torch.empty(*shapes, dtype=x.dtype, device=x.device) for _ in range(world_size)] + torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) + + return torch.cat(gather_list, dim=gather_dim).contiguous() + + +class _DualAllToAll(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, pg, scatter_dim, gather_dim): + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + ctx.pg = pg + return _all_to_all(x, pg, scatter_dim, gather_dim) + + @staticmethod + def backward(ctx, grad): + return _all_to_all(grad, ctx.pg, ctx.gather_dim, ctx.scatter_dim), None, None, None + + +def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int): + return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim) + + +### table wise embedding shard + + +def _all_to_all_for_tablewise(x: torch.Tensor, + pg: ProcessGroup, + scatter_strides: List[int], + gather_strides: List[int], + forward=True) -> torch.Tensor: + world_size = pg.tp_world_size() + rank = pg.tp_local_rank() + if world_size == 1: + return x + assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend" + if forward: + scatter_list = list(x.split(scatter_strides, 0)) + gather_list = [ + torch.empty(scatter_strides[rank], gather_strides[i], dtype=x.dtype, device=x.device) + for i in range(world_size) + ] + torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) + return torch.cat(gather_list, 1).contiguous() + else: + # split on dim 1, lose contiguity + scatter_list = [each.contiguous() for each in x.split(scatter_strides, 1)] + gather_list = [ + torch.empty(gather_strides[i], scatter_strides[rank], dtype=x.dtype, device=x.device) + for i in range(world_size) + ] + torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) + return torch.cat(gather_list, 0).contiguous() + + +class _DualAllToAllForTablewise(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, pg, scatter_strides, gather_strides): + ctx.pg = pg + ctx.scatter_strides = scatter_strides + ctx.gather_strides = gather_strides + return _all_to_all_for_tablewise(x, pg, scatter_strides, gather_strides, forward=True) + + @staticmethod + def backward(ctx, grad): + return _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides, + forward=False), None, None, None + + +def dual_all_to_all_tablewise(x, pg, scatter_strides, gather_strides): + return _DualAllToAllForTablewise.apply(x, pg, scatter_strides, gather_strides) diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/nn/_ops/addmm.py new file mode 100644 index 0000000000000000000000000000000000000000..660b48a71d57ee4cc9b1e0cc904f9b7f8d69e244 --- /dev/null +++ b/colossalai/nn/_ops/addmm.py @@ -0,0 +1,90 @@ +import torch + +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec +from colossalai.tensor.op_wrapper import colo_op_impl + +from ._utils import GeneralTensor, Number, convert_to_colo_tensor, reduce_grad, reduce_input + + +def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, + alpha: Number) -> ColoTensor: + # mat1:S[1] x mat2:S[0] = Output:P + # beta * input + alpha * All-Reduce(Output) = res + + mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]), mat2.get_process_group()) + + # Output:P + partial_output = torch.mm(mat1, mat2) + # Reduce(Output) + output = reduce_input(partial_output, mat2.get_process_group()) + # input + assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op' + output = beta * input_tensor + alpha * output + output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(input_tensor.get_process_group())) + return output + + +def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, + alpha: Number) -> ColoTensor: + # mat1:B x mat2:S[1] + input:S[1] = Output:S[1] + compute_spec = mat2.compute_spec + mat1 = mat1.redistribute(ReplicaSpec()) + mat1 = reduce_grad(mat1, mat1.get_process_group()) + + output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha) + output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]), + ComputeSpec(ComputePattern.TP1D)) + output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) + + if compute_spec.output_replicate: + return output.to_replicate() + else: + return output + + +def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, + alpha: Number) -> ColoTensor: + assert mode in ('row', 'col') + funcs = {'row': colo_addmm_1Drow, 'col': colo_addmm_1Dcol} + return funcs[mode](input_tensor, mat1, mat2, beta, alpha) + + +@colo_op_impl(torch.addmm) +def colo_addmm(input_tensor: GeneralTensor, + mat1: ColoTensor, + mat2: ColoTensor, + beta: Number = 1, + alpha: Number = 1, + **kargs) -> ColoTensor: + """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. + This method computes a linear. + """ + # At least one of the tensor should be ColoTensor + assert isinstance(mat2, ColoTensor) + input_tensor = convert_to_colo_tensor(input_tensor, mat2.get_process_group()) + mat1 = convert_to_colo_tensor(mat1, mat2.get_process_group()) + + # Add communication logic before and after linear call. + ret_tensor = None + if not mat2.has_compute_spec(): # No Model Parallel Applied + assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op' + assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op' + ret_tensor = ColoTensor.from_torch_tensor(tensor=torch.addmm(input_tensor, + mat1, + mat2, + beta=beta, + alpha=alpha, + **kargs), + spec=ColoTensorSpec(mat2.get_process_group())) + elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied + if mat2.is_shard_1drow() and input_tensor.is_replicate(): + mode = 'row' + elif mat2.is_shard_1dcol() and (input_tensor.is_shard_1dcol() or input_tensor.is_shard_1drow()): + mode = 'col' + else: + raise NotImplementedError + ret_tensor = colo_addmm_1d(mode, input_tensor, mat1, mat2, beta, alpha) + else: + raise NotImplementedError + + return ret_tensor diff --git a/colossalai/nn/_ops/batch_norm.py b/colossalai/nn/_ops/batch_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..54ecc88f420a8d8a2c81aca6ac765e57f5a56cac --- /dev/null +++ b/colossalai/nn/_ops/batch_norm.py @@ -0,0 +1,33 @@ +from typing import Optional + +import torch.nn.functional as F + +from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec +from colossalai.tensor.op_wrapper import colo_op_impl + +from ._utils import GeneralTensor, convert_to_colo_tensor + + +@colo_op_impl(F.batch_norm) +def colo_batch_norm( + input: GeneralTensor, + running_mean: Optional[GeneralTensor], + running_var: Optional[GeneralTensor], + weight: Optional[GeneralTensor] = None, + bias: Optional[GeneralTensor] = None, + training: bool = False, + momentum: float = 0.1, + eps: float = 1e-5, +): + assert isinstance(weight, ColoTensor) + running_mean = running_mean.detach() + running_var = running_var.detach() + + input = convert_to_colo_tensor(input, weight.get_process_group()) + bias = convert_to_colo_tensor(bias, weight.get_process_group()) + input = input.redistribute(ReplicaSpec()) + bias = bias.redistribute(ReplicaSpec()) + + output = F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps) + output = ColoTensor.from_torch_tensor(tensor=output, spec=ColoTensorSpec(pg=weight.get_process_group())) + return output diff --git a/colossalai/nn/_ops/element_wise.py b/colossalai/nn/_ops/element_wise.py new file mode 100644 index 0000000000000000000000000000000000000000..2de51e24a6dd1bd45271a0b8d51372ee5209415d --- /dev/null +++ b/colossalai/nn/_ops/element_wise.py @@ -0,0 +1,250 @@ +import torch +import torch.nn.functional as F +from torch import Tensor + +from colossalai.tensor import ColoTensor, ColoTensorSpec +from colossalai.tensor.op_wrapper import colo_op_impl + +from ._utils import GeneralTensor, convert_to_colo_tensor + + +def register_elementwise_op(op): + + @colo_op_impl(op) + def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs): + """ + Handles ``__torch_function__`` dispatch for the elementwise op such + as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``. + This method computes on either a normal tensor or a sharded tensor. + """ + if 'inplace' in kwargs: + # TODO(jiaruifang) inplace will cause bugs + input_tensor = input_tensor.clone() + return op(input_tensor, *args, **kwargs) + else: + output = op(input_tensor, *args, **kwargs) + # return output + if isinstance(input_tensor, ColoTensor): + if isinstance(output, str): + return output + if not isinstance(output, torch.Tensor): + raise NotImplementedError + return ColoTensor.from_torch_tensor(output, + spec=ColoTensorSpec(input_tensor.get_process_group(), + dist_attr=input_tensor.dist_spec)) + + +# @colo_op_impl(torch.relu_) +# def elementwise_op(input_tensor): +# torch.relu_(input_tensor.data) +# return input_tensor + +# @colo_op_impl(Tensor.add_) +# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs): +# input_tensor = input_tensor.data.add_(*args, **kwargs) +# return input_tensor + +# Tensor op +register_elementwise_op(Tensor.abs) +register_elementwise_op(Tensor.absolute) +register_elementwise_op(Tensor.acos) +register_elementwise_op(Tensor.arccos) +register_elementwise_op(Tensor.angle) +register_elementwise_op(Tensor.asin) +register_elementwise_op(Tensor.arcsin) +register_elementwise_op(Tensor.atan) +register_elementwise_op(Tensor.arctan) +register_elementwise_op(Tensor.all) +register_elementwise_op(Tensor.any) +register_elementwise_op(Tensor.bernoulli) +register_elementwise_op(Tensor.bfloat16) +register_elementwise_op(Tensor.bitwise_not) +register_elementwise_op(Tensor.bool) +register_elementwise_op(Tensor.byte) +register_elementwise_op(Tensor.ceil) +register_elementwise_op(Tensor.char) +register_elementwise_op(Tensor.clamp) +register_elementwise_op(Tensor.clamp_max) +register_elementwise_op(Tensor.clamp_min) +register_elementwise_op(Tensor.clip) +register_elementwise_op(Tensor.clone) +register_elementwise_op(Tensor.contiguous) +register_elementwise_op(Tensor.copysign) +register_elementwise_op(Tensor.cos) +register_elementwise_op(Tensor.cosh) +register_elementwise_op(Tensor.acosh) +register_elementwise_op(Tensor.arccosh) +register_elementwise_op(Tensor.cpu) +register_elementwise_op(Tensor.cuda) +register_elementwise_op(Tensor.deg2rad) +register_elementwise_op(Tensor.detach) +register_elementwise_op(Tensor.digamma) +register_elementwise_op(Tensor.double) +register_elementwise_op(Tensor.erf) +register_elementwise_op(Tensor.erfc) +register_elementwise_op(Tensor.erfinv) +register_elementwise_op(Tensor.exp) +register_elementwise_op(Tensor.expm1) +register_elementwise_op(Tensor.fix) +register_elementwise_op(Tensor.trunc) +register_elementwise_op(Tensor.float) +register_elementwise_op(Tensor.float_power) +register_elementwise_op(Tensor.floor) +register_elementwise_op(Tensor.frac) +register_elementwise_op(Tensor.half) +register_elementwise_op(Tensor.hardshrink) +register_elementwise_op(Tensor.heaviside) +register_elementwise_op(Tensor.i0) +register_elementwise_op(Tensor.int) +register_elementwise_op(Tensor.isfinite) +register_elementwise_op(Tensor.isinf) +register_elementwise_op(Tensor.isposinf) +register_elementwise_op(Tensor.isneginf) +register_elementwise_op(Tensor.isnan) +register_elementwise_op(Tensor.lgamma) +register_elementwise_op(Tensor.log) +register_elementwise_op(Tensor.log10) +register_elementwise_op(Tensor.log1p) +register_elementwise_op(Tensor.log2) +register_elementwise_op(Tensor.logical_not) +register_elementwise_op(Tensor.logit) +register_elementwise_op(Tensor.long) +register_elementwise_op(Tensor.nan_to_num) +register_elementwise_op(Tensor.neg) +register_elementwise_op(Tensor.negative) +register_elementwise_op(Tensor.positive) +register_elementwise_op(Tensor.pow) +register_elementwise_op(Tensor.rad2deg) +register_elementwise_op(Tensor.reciprocal) +register_elementwise_op(Tensor.round) +register_elementwise_op(Tensor.rsqrt) +register_elementwise_op(Tensor.short) +register_elementwise_op(Tensor.sigmoid) +register_elementwise_op(Tensor.sign) +register_elementwise_op(Tensor.signbit) +register_elementwise_op(Tensor.sgn) +register_elementwise_op(Tensor.sin) +register_elementwise_op(Tensor.sinc) +register_elementwise_op(Tensor.sinh) +register_elementwise_op(Tensor.asinh) +register_elementwise_op(Tensor.arcsinh) +register_elementwise_op(Tensor.sqrt) +register_elementwise_op(Tensor.square) +register_elementwise_op(Tensor.to) +register_elementwise_op(Tensor.tan) +register_elementwise_op(Tensor.tanh) +register_elementwise_op(Tensor.atanh) +register_elementwise_op(Tensor.arctanh) +register_elementwise_op(Tensor.type) +register_elementwise_op(Tensor.type_as) + +# torch OP +register_elementwise_op(torch.abs) +register_elementwise_op(torch.absolute) +register_elementwise_op(torch.acos) +register_elementwise_op(torch.arccos) +register_elementwise_op(torch.angle) +register_elementwise_op(torch.asin) +register_elementwise_op(torch.arcsin) +register_elementwise_op(torch.atan) +register_elementwise_op(torch.arctan) +register_elementwise_op(torch.all) +register_elementwise_op(torch.any) +register_elementwise_op(torch.bernoulli) +register_elementwise_op(torch.bitwise_not) +register_elementwise_op(torch.ceil) +register_elementwise_op(torch.clamp) +register_elementwise_op(torch.clamp_max) +register_elementwise_op(torch.clamp_min) +register_elementwise_op(torch.clip) +register_elementwise_op(torch.clone) +register_elementwise_op(torch.copysign) +register_elementwise_op(torch.cos) +register_elementwise_op(torch.cosh) +register_elementwise_op(torch.acosh) +register_elementwise_op(torch.arccosh) +register_elementwise_op(torch.deg2rad) +register_elementwise_op(torch.digamma) +register_elementwise_op(torch.erf) +register_elementwise_op(torch.erfc) +register_elementwise_op(torch.erfinv) +register_elementwise_op(torch.exp) +register_elementwise_op(torch.expm1) +register_elementwise_op(torch.fix) +register_elementwise_op(torch.trunc) +register_elementwise_op(torch.float_power) +register_elementwise_op(torch.floor) +register_elementwise_op(torch.frac) +register_elementwise_op(torch.hardshrink) +register_elementwise_op(torch.heaviside) +register_elementwise_op(torch.i0) +register_elementwise_op(torch.isfinite) +register_elementwise_op(torch.isinf) +register_elementwise_op(torch.isposinf) +register_elementwise_op(torch.isneginf) +register_elementwise_op(torch.isnan) +register_elementwise_op(torch.lgamma) +register_elementwise_op(torch.log) +register_elementwise_op(torch.log10) +register_elementwise_op(torch.log1p) +register_elementwise_op(torch.log2) +register_elementwise_op(torch.logical_not) +register_elementwise_op(torch.logit) +register_elementwise_op(torch.nan_to_num) +register_elementwise_op(torch.neg) +register_elementwise_op(torch.negative) +register_elementwise_op(torch.positive) +register_elementwise_op(torch.pow) +register_elementwise_op(torch.rad2deg) +register_elementwise_op(torch.reciprocal) +register_elementwise_op(torch.round) +register_elementwise_op(torch.rsqrt) +register_elementwise_op(torch.sigmoid) +register_elementwise_op(torch.sign) +register_elementwise_op(torch.signbit) +register_elementwise_op(torch.sgn) +register_elementwise_op(torch.sin) +register_elementwise_op(torch.sinc) +register_elementwise_op(torch.sinh) +register_elementwise_op(torch.asinh) +register_elementwise_op(torch.arcsinh) +register_elementwise_op(torch.sqrt) +register_elementwise_op(torch.square) +register_elementwise_op(torch.tan) +register_elementwise_op(torch.tanh) +register_elementwise_op(torch.atanh) +register_elementwise_op(torch.arctanh) +register_elementwise_op(torch.zeros_like) + +# nn.functional OP +register_elementwise_op(F.threshold) +register_elementwise_op(F.relu) +register_elementwise_op(F.hardtanh) +register_elementwise_op(F.hardswish) +register_elementwise_op(F.relu6) +register_elementwise_op(F.elu) +register_elementwise_op(F.selu) +register_elementwise_op(F.celu) +register_elementwise_op(F.leaky_relu) +register_elementwise_op(F.prelu) +register_elementwise_op(F.rrelu) +register_elementwise_op(F.gelu) +register_elementwise_op(F.logsigmoid) +register_elementwise_op(F.hardshrink) +register_elementwise_op(F.tanhshrink) +register_elementwise_op(F.softsign) +register_elementwise_op(F.softplus) +register_elementwise_op(F.softmin) +register_elementwise_op(F.softmax) +register_elementwise_op(F.softshrink) +register_elementwise_op(F.gumbel_softmax) +register_elementwise_op(F.log_softmax) +register_elementwise_op(F.tanh) +register_elementwise_op(F.sigmoid) +register_elementwise_op(F.hardsigmoid) +register_elementwise_op(F.silu) +register_elementwise_op(F.mish) +# TODO(ver217): dropout handles seed +register_elementwise_op(F.dropout) +register_elementwise_op(F.alpha_dropout) +register_elementwise_op(F.feature_alpha_dropout) diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/nn/_ops/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..a045f305b5dc72454043298b2a69f114ad50f1e9 --- /dev/null +++ b/colossalai/nn/_ops/embedding.py @@ -0,0 +1,140 @@ +import torch.nn.functional as F +from typing import Optional +from colossalai.tensor.op_wrapper import colo_op_impl +from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, \ + ReplicaSpec +from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input + + +def colo_embedding_1Dcol(input_tensor: ColoTensor, + weight: ColoTensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False) -> ColoTensor: + # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P) + # Gather splitted lookup table + input_tensor = input_tensor.redistribute(ReplicaSpec()) + + output_parallel = F.embedding(input_tensor, + weight, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]), + ComputeSpec(ComputePattern.TP1D)) + output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) + + compute_spec = weight.compute_spec + + if compute_spec.output_replicate: + return output.to_replicate() + else: + return output + + +def colo_embedding_1Drow(input_tensor: ColoTensor, + weight: ColoTensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False) -> ColoTensor: + # embedding_1Drow splits the weight(lookup table) to the shape, [num_embeddings/P, embedding_dim] + # get the index of current segment and mask other segments with 0 + + # get complete input tensor through all-gather + input_tensor = input_tensor.redistribute(ReplicaSpec()) + + # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + tensor_parallel_rank = weight.get_process_group().tp_local_rank() + num_embeddings_per_partition = weight.size_local(0) + vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition + vocab_end_index = vocab_start_index + num_embeddings_per_partition + + # build the mask. + input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index) + # mask the input. + # TODO(jzy) masked_input may be an activation managed by ColoTensor. + masked_input = input_tensor - vocab_start_index + masked_input[input_mask] = 0 + + partial_output = F.embedding(masked_input, + weight, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + + # Mask the output embedding. + partial_output[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_input(partial_output, weight.get_process_group()) + output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec())) + return output + + +def colo_embedding_1d(mode: str, + input_tensor: ColoTensor, + weight: ColoTensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False) -> ColoTensor: + assert mode in ('row', 'col') + funcs = {'row': colo_embedding_1Drow, 'col': colo_embedding_1Dcol} + return funcs[mode](input_tensor, + weight, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + + +@colo_op_impl(F.embedding) +def colo_embedding(input_tensor: GeneralTensor, + weight: GeneralTensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False): + """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``. + This method looks up an embedding table. + """ + assert isinstance(weight, ColoTensor) + input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) + + if not weight.has_compute_spec(): # No Model Parallel Applied + assert weight.is_replicate(), 'Invalid weight spec for native embedding op' + return ColoTensor.from_torch_tensor(tensor=F.embedding(input_tensor, + weight, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse), + spec=ColoTensorSpec(weight.get_process_group())) + elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied + if weight.is_shard_1drow(): + mode = 'row' + elif weight.is_shard_1dcol(): + mode = 'col' + else: + raise NotImplementedError + return colo_embedding_1d(mode, + input_tensor, + weight, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + else: + raise NotImplementedError diff --git a/colossalai/nn/_ops/embedding_bag.py b/colossalai/nn/_ops/embedding_bag.py new file mode 100644 index 0000000000000000000000000000000000000000..0026f579b6dccc1344ba0f32cf47d4131f30f7e6 --- /dev/null +++ b/colossalai/nn/_ops/embedding_bag.py @@ -0,0 +1,125 @@ +import torch.nn.functional as F +from typing import Optional +from torch import Tensor +from colossalai.tensor.op_wrapper import colo_op_impl +from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, \ + ShardSpec, ReplicaSpec +from ._utils import GeneralTensor, convert_to_colo_tensor + + +def colo_embedding_bag_1Dcol(input_tensor: ColoTensor, + weight: ColoTensor, + offsets: Optional[Tensor] = None, + max_norm: Optional[float] = None, + norm_type: float = 2, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + per_sample_weights: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None) -> ColoTensor: + # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P) + # Gather splitted lookup table + pg = weight.get_process_group() + input_tensor = input_tensor.redistribute(ReplicaSpec()) + + output_parallel = F.embedding_bag(input_tensor, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + padding_idx=padding_idx) + output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) + + if weight.compute_spec.output_replicate: + return output.to_replicate() + else: + return output + + +def colo_embedding_bag_1d(tp_mode: str, + input_tensor: ColoTensor, + weight: ColoTensor, + offsets: Optional[Tensor] = None, + max_norm: Optional[float] = None, + norm_type: float = 2, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + per_sample_weights: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None) -> ColoTensor: + assert tp_mode in ('col',) + funcs = {'col': colo_embedding_bag_1Dcol} + return funcs[tp_mode](input_tensor, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + padding_idx=padding_idx) + + +@colo_op_impl(F.embedding_bag) +def colo_embedding_bag(input_tensor: GeneralTensor, + weight: GeneralTensor, + offsets: Optional[Tensor] = None, + max_norm: Optional[float] = None, + norm_type: float = 2, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + per_sample_weights: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None): + """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``. + This method looks up an embedding table. + """ + assert isinstance(weight, ColoTensor) + input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) + + # Handle different parallel actions. + + if not weight.has_compute_spec(): # No Model Parallel Applied + assert weight.is_replicate(), 'Invalid weight spec for native embedding op' + return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + padding_idx=padding_idx), + spec=ColoTensorSpec(weight.get_process_group())) + elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied + if weight.is_shard_1dcol(): + tp_mode = 'col' + else: + raise NotImplementedError + return colo_embedding_bag_1d(tp_mode, + input_tensor, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + padding_idx=padding_idx) + else: + raise NotImplementedError diff --git a/colossalai/nn/_ops/layernorm.py b/colossalai/nn/_ops/layernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..2b761b84e3ee8aa9dcaf7ef1ba054b6857aa4b49 --- /dev/null +++ b/colossalai/nn/_ops/layernorm.py @@ -0,0 +1,25 @@ +from typing import List, Optional +import torch.nn.functional as F +from colossalai.tensor.op_wrapper import colo_op_impl +from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec, ReplicaSpec +from ._utils import GeneralTensor, convert_to_colo_tensor + + +@colo_op_impl(F.layer_norm) +def colo_layernorm( + input_tensor: GeneralTensor, + normalized_shape: List[int], + weight: Optional[GeneralTensor] = None, + bias: Optional[GeneralTensor] = None, + eps: float = 1e-5, +): + assert isinstance(weight, ColoTensor) + input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) + bias = convert_to_colo_tensor(bias, weight.get_process_group()) + input_tensor = input_tensor.redistribute(ReplicaSpec()) + + output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps) + output = ColoTensor.from_torch_tensor(tensor=output, + spec=ColoTensorSpec(pg=input_tensor.get_process_group(), + dist_attr=input_tensor.dist_spec)) + return output diff --git a/colossalai/nn/_ops/linear.py b/colossalai/nn/_ops/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..2f2088c61fa842b7eff6eecd348e6ff7d42916cf --- /dev/null +++ b/colossalai/nn/_ops/linear.py @@ -0,0 +1,171 @@ +from copy import deepcopy +from typing import Optional + +import torch.nn.functional as F + +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec +from colossalai.tensor.op_wrapper import colo_op_impl +from colossalai.tensor.sharding_spec import ShardingSpec + +from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_grad, reduce_input + + +def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': + # Input:S[1] x Weight:S[0] = Output:P + # All-Reduce(Output) + bias = res + # Input:S[1] + pg = weight.get_process_group() + input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]), pg) + + # Output:P + partial_output = F.linear(input_tensor, weight) + # Reduce(Output) + + output = reduce_input(partial_output, pg) + # Bias + if bias is not None: + assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op' + output = output + bias + + output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec())) + return output + + +def colo_linear_1dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': + # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1] + # All-Gather(Output) + # Input:B + compute_spec = weight.compute_spec + input_tensor = input_tensor.redistribute(ReplicaSpec()) + input_parallel = reduce_grad(input_tensor, weight.get_process_group()) + + output_parallel = F.linear(input_parallel, weight, bias) + output = ColoTensor.from_torch_tensor(output_parallel, + spec=ColoTensorSpec(weight.get_process_group(), + ShardSpec([-1], [weight.get_tp_world_size()]), + ComputeSpec(ComputePattern.TP1D))) + if compute_spec.output_replicate: + return output.to_replicate() + else: + return output + + +def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': + assert mode in ('row', 'col') + funcs = {'row': colo_linear_1drow, 'col': colo_linear_1dcol} + return funcs[mode](input_tensor, weight, bias) + + +# @register_colo_graph(input_pos=[1], param_pos=[2, 3]) +def colo_linear_imp(input_tensor: GeneralTensor, + weight: GeneralTensor, + bias: Optional[GeneralTensor] = None) -> 'ColoTensor': + """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. + This method computes a linear. + """ + assert isinstance(weight, ColoTensor) + pg = weight.get_process_group() + assert pg + input_tensor = convert_to_colo_tensor(input_tensor, pg) + bias = convert_to_colo_tensor(bias, pg) + # input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias))) + + # Add communication logic before and after linear call. + ret_tensor = None + if not weight.has_compute_spec(): # No Model Parallel Applied + assert weight.is_replicate(), 'Invalid weight spec for native Linear op' + assert bias is None or bias.is_replicate(), 'Invalid bias spec for native Linear op' + ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias), spec=ColoTensorSpec(pg)) + elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied + if weight.is_shard_1dcol() and (bias is None or bias.is_replicate()): + mode = 'row' + elif weight.is_shard_1drow() and (bias is None or bias.is_shard_1drow() or bias.is_shard_1dcol()): + mode = 'col' + else: + raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight}, bias {bias}") + ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias) + else: + raise NotImplementedError + + return ret_tensor + + +def _new_colo_linear_imp(input_tensor: GeneralTensor, + weight: GeneralTensor, + bias: Optional[GeneralTensor] = None) -> 'ColoTensor': + """ + A tentative function to compute the distributed linear layer with the latest sharding spec. + This function is subject to future change as the current sharding API is not stable. + """ + # get mesh info + input_sharding_seq = input_tensor.sharding_spec.sharding_sequence + weight_sharding_seq = weight.sharding_spec.sharding_sequence + if bias is not None: + bias_sharding_seq = bias.sharding_spec.sharding_sequence + device_mesh = weight.sharding_spec.device_mesh + pg_axis0 = weight.pg_axis0 + pg_axis1 = weight.pg_axis1 + + # the last dim of input should have the same spec as the first dim of weight + # the weight is transposed, so we look at the second dimension + assert input_sharding_seq[-1] == weight_sharding_seq[1] + + if bias is not None: + assert bias_sharding_seq[0] == weight_sharding_seq[0] + + # compute the output sharding sequence + # as weight is transposed, so we look at the first dimension + output_shard_seq = input_sharding_seq[:-1] + weight_sharding_seq[:1] + output_shard_seq = deepcopy(output_shard_seq) + + # TODO: add reduce grad logic + + # handle column and row parallel linear + # by reusing the implementation above + out = F.linear(input_tensor, weight) + + # run all reduce if necessary + last_dim_spec = input_sharding_seq[-1] + if last_dim_spec.is_replica: + pass + elif last_dim_spec.shard_list is not None: + for dim in last_dim_spec.shard_list: + if dim == 0: + reduce_input(out, pg_axis0) + elif dim == 1: + reduce_input(out, pg_axis1) + else: + raise RuntimeError("Found invalid sharding axis {dim}, only 0 or 1 is expected") + # add bias + if bias is not None: + out += bias + + # convert shard seq to partition dict + output_partition_dict = {} + for index, dim_spec in enumerate(output_shard_seq): + if not dim_spec.is_replica: + if index not in output_partition_dict: + output_partition_dict[index] = [] + output_partition_dict[index].extend(dim_spec.shard_list) + + entire_shape = out.shape + output_sharding_spec = ShardingSpec(device_mesh, entire_shape, output_partition_dict) + ret_tensor = ColoTensor.from_torch_tensor(out) + setattr(ret_tensor, 'sharding_spec', output_sharding_spec) + return ret_tensor + + +def _has_sharding_spec(tensor): + """ + A tentative function to check whether the tensor is using the new sharding spec API. We assume that the sharding spec object is + set as the attribute `sharding_spec` on a tensor. + """ + return hasattr(tensor, 'sharding_spec') + + +@colo_op_impl(F.linear) +def colo_linear(input: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor': + if _has_sharding_spec(weight): + return _new_colo_linear_imp(input, weight, bias) + else: + return colo_linear_imp(input, weight, bias) diff --git a/colossalai/nn/_ops/loss.py b/colossalai/nn/_ops/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1e54f662859ceffa3edb66157c187c6e17658142 --- /dev/null +++ b/colossalai/nn/_ops/loss.py @@ -0,0 +1,48 @@ +import torch +import torch.nn.functional as F +from typing import Optional +from colossalai.tensor.op_wrapper import colo_op_impl +from colossalai.tensor import ColoTensor, ColoTensorSpec +from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D +from ._utils import GeneralTensor, convert_to_colo_tensor + + +@colo_op_impl(F.cross_entropy) +def colo_cross_entropy(input_tensor: GeneralTensor, + target: GeneralTensor, + weight: Optional[GeneralTensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", + label_smoothing: float = 0.0): + assert isinstance(weight, ColoTensor) or isinstance(target, ColoTensor) or isinstance(input_tensor, ColoTensor) + pg = input_tensor.get_process_group() if isinstance(input_tensor, ColoTensor) else isinstance(target, ColoTensor) + weight = convert_to_colo_tensor(weight, pg) + target = convert_to_colo_tensor(target, pg) + input_tensor = convert_to_colo_tensor(input_tensor, pg) + + if input_tensor.is_replicate(): # Input is gathered + assert target.is_replicate() and (weight is None or weight.is_replicate()), \ + "Target tensor and weight tensor both should be complete" + output = F.cross_entropy(input_tensor, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + label_smoothing=label_smoothing) + return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)) + elif input_tensor.has_compute_spec(): # Single Model Parallel Applied + if input_tensor.is_shard_1dcol(): + assert weight is None, "Current TP cross entropy loss function doesn't support passing weight tensor in" + assert target.is_replicate(), "Target tensor should be complete in TP cross entropy loss function" + output = VocabParallelCrossEntropyLoss1D()(input_tensor, + target, + process_group=input_tensor.process_group.tp_process_group()) + return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)) + else: + raise NotImplementedError + else: + raise NotImplementedError diff --git a/colossalai/nn/_ops/view.py b/colossalai/nn/_ops/view.py new file mode 100644 index 0000000000000000000000000000000000000000..3c0bc52337ce8f053f04fbc2885fa5aba879990a --- /dev/null +++ b/colossalai/nn/_ops/view.py @@ -0,0 +1,96 @@ +import operator +from functools import reduce +from typing import Optional, Union + +import torch + +from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec +from colossalai.tensor.op_wrapper import colo_op_impl + + +def _all_int(my_iter): + return all(isinstance(i, int) for i in my_iter) + + +def _get_valid_shape(shape): + if isinstance(shape, list): + if _all_int(shape): + return tuple(shape) + else: + raise RuntimeError("expects type(int) but finds an other type") + elif isinstance(shape, tuple): + if _all_int(shape): + return shape + else: + return _get_valid_shape(shape[0]) + else: + raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape))) + + +def _shape_infer(org_sp, tgt_sp): + cnt = 0 + pos = 0 + for idx, dim in enumerate(tgt_sp): + if dim < -1: + raise RuntimeError("invalid shape dimension {}".format(dim)) + elif dim == -1: + cnt += 1 + pos = idx + + if cnt > 1: + raise RuntimeError("only one dimension can be inferred") + + org_prod = reduce(operator.mul, org_sp, 1) + tgt_prod = reduce(operator.mul, tgt_sp, 1) + + if cnt == 0: + if org_prod != tgt_prod: + raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod)) + else: + return tgt_sp + elif org_prod % tgt_prod != 0: + raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod)) + + infer_dim = -(org_prod // tgt_prod) + return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:] + + +@colo_op_impl(torch.Tensor.view) +def colo_view(self: ColoTensor, *shape) -> 'ColoTensor': + """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``. + Changes the shape of the current tensor. + """ + assert isinstance(self, ColoTensor) + # apply original `view` function for replicated colo tensors + if self.is_replicate(): + return self.view(*shape) + + cur_sp = self.size() + org_sp = self.size_global() + # parse the passed arguments + tgt_sp = _get_valid_shape(shape) + # get the correct shape from inference + inf_sp = _shape_infer(org_sp, tgt_sp) + + if self.is_shard_1drow() and org_sp[0] == inf_sp[0]: + new_shape = (cur_sp[0],) + tgt_sp[1:] + res = self.view(*new_shape) + elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]: + new_shape = tgt_sp[:-1] + (cur_sp[-1],) + res = self.view(*new_shape) + else: + replicated_t = self.redistribute(dist_spec=ReplicaSpec()) + return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape), + spec=ColoTensorSpec(self.get_process_group())) + + return ColoTensor.from_torch_tensor(tensor=res, + spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec)) + + +@colo_op_impl(torch.Tensor.size) +def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]: + size = self.size_global() + if dim is None: + return size + else: + return size[dim] diff --git a/colossalai/nn/init.py b/colossalai/nn/init.py new file mode 100644 index 0000000000000000000000000000000000000000..559b7038fc352a0ccbea22403cd4b1284bed42e0 --- /dev/null +++ b/colossalai/nn/init.py @@ -0,0 +1,252 @@ +import math +import warnings + +from torch import Tensor +import torch.nn as nn + + +def zeros_(): + """Return the initializer filling the input Tensor with the scalar zeros""" + + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + return nn.init.zeros_(tensor) + + return initializer + + +def ones_(): + """Return the initializer filling the input Tensor with the scalar ones""" + + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + return nn.init.ones_(tensor) + + return initializer + + +def uniform_(a: float = 0., b: float = 1.): + r"""Return the initializer filling the input Tensor with values drawn from the uniform + distribution :math:`\mathcal{U}(a, b)`. + + Args: + a (float): the lower bound of the uniform distribution. Defaults 0.0. + b (float): the upper bound of the uniform distribution. Defaults 1.0. + """ + + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + return nn.init.uniform_(tensor, a, b) + + return initializer + + +def normal_(mean: float = 0., std: float = 1.): + r"""Return the initializer filling the input Tensor with values drawn from the normal distribution + + .. math:: + \mathcal{N}(\text{mean}, \text{std}^2) + + Args: + mean (float): the mean of the normal distribution. Defaults 0.0. + std (float): the standard deviation of the normal distribution. Defaults 1.0. + """ + + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + return nn.init.normal_(tensor, mean, std) + + return initializer + + +def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.): + r"""Return the initializer filling the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + Args: + mean (float): the mean of the normal distribution. Defaults 0.0. + std (float): the standard deviation of the normal distribution. Defaults 1.0. + a (float): the minimum cutoff value. Defaults -2.0. + b (float): the maximum cutoff value. Defaults 2.0. + """ + + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + return nn.init.trunc_normal_(tensor, mean, std, a, b) + + return initializer + + +def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'): + r"""Return the initializer filling the input `Tensor` with values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification` - He, K. et al. (2015), using a + uniform distribution. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan_mode}}} + + Also known as 'He initialization'. + + Args: + a (int): the negative slope of the rectifier used after this layer (only used with ``'leaky_relu'``). + mode (str, optional): either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity (str, optional): the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + """ + + # adapted from torch.nn.init + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + + if mode == 'fan_in': + assert fan_in is not None, 'Fan_in is not provided.' + fan = fan_in + elif mode == 'fan_out': + assert fan_out is not None, 'Fan_out is not provided.' + fan = fan_out + else: + raise ValueError(f'Invalid initialization mode \'{mode}\'') + + std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan) + bound = math.sqrt(3.) * std + return nn.init.uniform_(tensor, -bound, bound) + + return initializer + + +def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'): + r"""Return the initializer filling the input `Tensor` with values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification` - He, K. et al. (2015), using a + normal distribution. The resulting tensor will have values sampled from + :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \frac{\text{gain}}{\sqrt{\text{fan_mode}}} + + Also known as 'He initialization'. + + Args: + a (int): the negative slope of the rectifier used after this layer (only used with ``'leaky_relu'``). + mode (str, optional): either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity (str, optional): the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + """ + + # adapted from torch.nn.init + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + + if mode == 'fan_in': + assert fan_in is not None, 'Fan_in is not provided.' + fan = fan_in + elif mode == 'fan_out': + assert fan_out is not None, 'Fan_out is not provided.' + fan = fan_out + else: + raise ValueError(f'Invalid initialization mode \'{mode}\'') + + std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan) + return nn.init.normal_(tensor, 0, std) + + return initializer + + +def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.): + r"""Return the initializer filling the input `Tensor` with values according to the method + described in `Understanding the difficulty of training deep feedforward + neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform + distribution. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-a, a)` where + + .. math:: + a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}} + + Also known as 'Glorot initialization'. + + Args: + a (float, optional): an optional scaling factor used to calculate uniform + bounds from standard deviation. Defaults ``math.sqrt(3.)``. + scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0. + gain (float, optional): an optional scaling factor. Defaults 1.0. + """ + + # adapted from torch.nn.init + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + assert fan_in is not None, 'Fan_in is not provided.' + + fan = fan_in + if fan_out is not None: + fan += fan_out + + std = gain * math.sqrt(scale / float(fan)) + bound = a * std + return nn.init.uniform_(tensor, -bound, bound) + + return initializer + + +def xavier_normal_(scale: float = 2., gain: float = 1.): + r"""Return the initializer filling the input `Tensor` with values according to the method + described in `Understanding the difficulty of training deep feedforward + neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal + distribution. The resulting tensor will have values sampled from + :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}} + + Also known as 'Glorot initialization'. + + Args: + scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0. + gain (float, optional): an optional scaling factor. Defaults 1.0. + """ + + # adapted from torch.nn.init + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + assert fan_in is not None, 'Fan_in is not provided.' + + fan = fan_in + if fan_out is not None: + fan += fan_out + + std = gain * math.sqrt(scale / float(fan)) + + return nn.init.normal_(tensor, 0., std) + + return initializer + + +def lecun_uniform_(): + # adapted from jax.nn.initializers + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + assert fan_in is not None, 'Fan_in is not provided.' + + var = 1.0 / fan_in + bound = math.sqrt(3 * var) + return nn.init.uniform_(tensor, -bound, bound) + + return initializer + + +def lecun_normal_(): + # adapted from jax.nn.initializers + def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): + assert fan_in is not None, 'Fan_in is not provided.' + + std = math.sqrt(1.0 / fan_in) + return nn.init.trunc_normal_(tensor, std=std / .87962566103423978) + + return initializer diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b705632f80407ceda019f4b905c724d1b2254f66 --- /dev/null +++ b/colossalai/nn/layer/__init__.py @@ -0,0 +1,10 @@ +from .colossalai_layer import * +from .parallel_1d import * +from .parallel_2d import * +from .parallel_2p5d import * +from .parallel_3d import * +from .parallel_sequence import * +from .moe import * +from .utils import * +from .vanilla import * +from .wrapper import * diff --git a/colossalai/nn/layer/base_layer.py b/colossalai/nn/layer/base_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..c85f53cc44c3660e39a1cf2995ca3e44a70b4c04 --- /dev/null +++ b/colossalai/nn/layer/base_layer.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.nn as nn + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from contextlib import contextmanager + + +class ParallelLayer(nn.Module): + global_state_dict: bool = True + + def __init__(self): + super().__init__() + self.data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank( + ParallelMode.DATA) + self.data_parallel_size = 1 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_world_size( + ParallelMode.DATA) + + self.tensor_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_local_rank( + ParallelMode.TENSOR) + self.tensor_parallel_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size( + ParallelMode.TENSOR) + + self.pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + + def _load_from_global_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + return super()._save_to_state_dict(destination, prefix, keep_vars) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + if self.global_state_dict: + if gpc.get_local_rank(ParallelMode.TENSOR) != 0: + missing_keys.clear() + unexpected_keys.clear() + return self._load_from_global_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, + unexpected_keys, error_msgs) + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if self.global_state_dict: + return self._save_to_global_state_dict(destination, prefix, keep_vars) + return super()._save_to_state_dict(destination, prefix, keep_vars) + + @classmethod + @contextmanager + def use_local_state_dict(cls): + try: + cls.global_state_dict = False + yield + finally: + cls.global_state_dict = True diff --git a/colossalai/nn/layer/colossalai_layer/__init__.py b/colossalai/nn/layer/colossalai_layer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ae1b07a75b2e7a231fc3512e8f46bccd0e9d4c6 --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/__init__.py @@ -0,0 +1,7 @@ +from ._utils import partition_batch +from .dropout import Dropout +from .embedding import Embedding, PatchEmbedding +from .linear import Classifier, Linear +from .normalization import LayerNorm + +__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch'] diff --git a/colossalai/nn/layer/colossalai_layer/_utils.py b/colossalai/nn/layer/colossalai_layer/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..677cb0e7ac428856c6888ae195ab32c7c70d5758 --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/_utils.py @@ -0,0 +1,41 @@ +import torch.nn as nn +from torch import Tensor + +from ..parallel_2d._operation import split_batch_2d +from ..parallel_2p5d._operation import split_batch_2p5d +from ..parallel_3d._operation import split_batch_3d +from ..utils import get_tensor_parallel_mode + +_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d} + + +def partition_batch(input_) -> Tensor: + tensor_parallel_mode = get_tensor_parallel_mode() + if tensor_parallel_mode in _parallel_split_batch: + if isinstance(input_, dict): + return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()} + else: + return _parallel_split_batch[tensor_parallel_mode](input_) + else: + return input_ + + +class ColossalaiModule(nn.Module): + + def __init__(self, module: nn.Module, **kwargs): + super().__init__() + self.module = module + for k, v in kwargs.items(): + setattr(self, k, v) + + def __getattr__(self, name: str): + if name == 'module': + return super().__getattr__(name) + elif hasattr(self.module, name): + return getattr(self.module, name) + elif name in self.__dict__: + return self.__dict__[name] + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name)) + + def forward(self, *args): + return self.module(*args) diff --git a/colossalai/nn/layer/colossalai_layer/dropout.py b/colossalai/nn/layer/colossalai_layer/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..0c049cb3f408e22eb9ce2b67d354da305aeadbae --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/dropout.py @@ -0,0 +1,31 @@ +import torch.nn as nn + +from colossalai.context import ParallelMode, seed + +from ..parallel_1d import * +from ..utils import get_tensor_parallel_mode +from ._utils import ColossalaiModule + + +class Dropout(ColossalaiModule): + """Dropout layer of colossalai. + + Args: + p (float, optional): probability of an element to be zeroed, defaults 0.5. + inplace (bool, optional): whether to do dropout in-place, default to be False. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False) -> None: + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel == "1d": + drop = Dropout1D(p, inplace) + else: + drop = nn.Dropout(p, inplace) + super().__init__(drop, tensor_parallel=tensor_parallel) + + def forward(self, *args): + if self.tensor_parallel in [None, '1d']: + return super().forward(*args) + else: + with seed(ParallelMode.TENSOR): + return super().forward(*args) diff --git a/colossalai/nn/layer/colossalai_layer/embedding.py b/colossalai/nn/layer/colossalai_layer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c9c46e0ff1e6fc3413809a493e7a1012273315 --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/embedding.py @@ -0,0 +1,151 @@ +import math +from typing import Callable + +from colossalai.utils import get_current_device +from torch import dtype, nn + +from ... import init as init +from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D +from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D +from ..parallel_2p5d import Embedding2p5D, PatchEmbedding2p5D, VocabParallelEmbedding2p5D +from ..parallel_3d import Embedding3D, PatchEmbedding3D, VocabParallelEmbedding3D +from ..utils import get_tensor_parallel_mode +from ..vanilla import VanillaPatchEmbedding +from ._utils import ColossalaiModule + +_parallel_embedding = { + '1d': Embedding1D, + '2d': Embedding2D, + '2.5d': Embedding2p5D, + '3d': Embedding3D, +} + +_vocab_parallel_embedding = { + '1d': VocabParallelEmbedding1D, + '2d': VocabParallelEmbedding2D, + '2.5d': VocabParallelEmbedding2p5D, + '3d': VocabParallelEmbedding3D +} + +_parallel_patchembedding = { + None: VanillaPatchEmbedding, + '1d': PatchEmbedding1D, + '2d': PatchEmbedding2D, + '2.5d': PatchEmbedding2p5D, + '3d': PatchEmbedding3D +} + + +class Embedding(ColossalaiModule): + r"""Embedding for colossalai. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: dtype = None, + weight_initializer: Callable = init.normal_(), + vocab_parallel_limit: int = 2048, + *args, + **kwargs) -> None: + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel is None: + embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, + **kwargs).to(dtype).to(get_current_device()) + weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) + elif num_embeddings <= vocab_parallel_limit: + embed = _parallel_embedding[tensor_parallel]( + num_embeddings, + embedding_dim, + padding_idx=padding_idx, + dtype=dtype, + weight_initializer=weight_initializer, + *args, + **kwargs, + ) + else: + embed = _vocab_parallel_embedding[tensor_parallel]( + num_embeddings, + embedding_dim, + padding_idx=padding_idx, + dtype=dtype, + weight_initializer=weight_initializer, + *args, + **kwargs, + ) + super().__init__(embed) + + +class PatchEmbedding(ColossalaiModule): + """2D Image to Patch Embedding. + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + dtype: dtype = None, + flatten: bool = True, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_() + ) -> None: + tensor_parallel = get_tensor_parallel_mode() + embed = _parallel_patchembedding[tensor_parallel]( + img_size, + patch_size, + in_chans, + embed_size, + dtype=dtype, + flatten=flatten, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + position_embed_initializer=position_embed_initializer, + ) + super().__init__(embed) diff --git a/colossalai/nn/layer/colossalai_layer/linear.py b/colossalai/nn/layer/colossalai_layer/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..3e0c6e285c1c64ccfc25bf5eca6f636cc8744aea --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/linear.py @@ -0,0 +1,141 @@ +import inspect +import math +from typing import Callable + +from torch import dtype, nn + +from colossalai.utils import get_current_device + +from ... import init as init +from ..parallel_1d import * +from ..parallel_2d import * +from ..parallel_2p5d import * +from ..parallel_3d import * +from ..utils import get_tensor_parallel_mode +from ..vanilla import * +from ._utils import ColossalaiModule + +_parallel_linear = {None: VanillaLinear, '1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D} + +_parallel_classifier = { + None: VanillaClassifier, + '1d': Classifier1D, + '2d': Classifier2D, + '2.5d': Classifier2p5D, + '3d': Classifier3D +} + +_vocab_parallel_classifier = { + '1d': VocabParallelClassifier1D, + '2d': VocabParallelClassifier2D, + '2.5d': VocabParallelClassifier2p5D, + '3d': VocabParallelClassifier3D +} + + +class Linear(ColossalaiModule): + """Linear layer of colossalai. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + Note: ``kwargs`` would contain different parameters when you use different parallelisms. + + The ``kwargs`` should contain parameters below: + :: + + Linear1D: + gather_output: bool (optional, default to be false) + skip_bias_add: bool (optional, default to be false) + Linear2D: + skip_bias_add: bool (optional, default to be false) + Linear2p5D: + skip_bias_add: bool (optional, default to be false) + Linear3D: + None + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs) -> None: + tensor_parallel = get_tensor_parallel_mode() + linear_cls = _parallel_linear[tensor_parallel] + gather_output = kwargs.pop('gather_output', None) + if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available + kwargs['gather_output'] = gather_output + layer = linear_cls( + in_features, + out_features, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + **kwargs, + ) + super().__init__(layer) + + +class Classifier(ColossalaiModule): + """Classifier layer of colossalai. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: nn.Parameter = None, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + vocab_parallel_limit: int = 2048) -> None: + tensor_parallel = get_tensor_parallel_mode() + if num_classes <= vocab_parallel_limit or tensor_parallel is None: + layer = _parallel_classifier[tensor_parallel]( + in_features, + num_classes, + weight=weight, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) + else: + layer = _vocab_parallel_classifier[tensor_parallel]( + in_features, + num_classes, + weight=weight, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) + super().__init__(layer) diff --git a/colossalai/nn/layer/colossalai_layer/normalization.py b/colossalai/nn/layer/colossalai_layer/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..86861d30214a43a95192d0f179be6ba705e002c8 --- /dev/null +++ b/colossalai/nn/layer/colossalai_layer/normalization.py @@ -0,0 +1,41 @@ +from colossalai.utils import get_current_device +from torch import nn + +from ..parallel_1d import LayerNorm1D +from ..parallel_2d import LayerNorm2D +from ..parallel_2p5d import LayerNorm2p5D +from ..parallel_3d import LayerNorm3D +from ..utils import get_tensor_parallel_mode +from ..vanilla import VanillaLayerNorm +from ._utils import ColossalaiModule + +_parallel_layernorm = { + None: VanillaLayerNorm, + "1d": LayerNorm1D, + "2d": LayerNorm2D, + "2.5d": LayerNorm2p5D, + "3d": LayerNorm3D, +} + + +class LayerNorm(ColossalaiModule): + r"""Layer Normalization for colossalai. + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None: + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel is None: + norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device()) + else: + norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) + super().__init__(norm) diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05333fe965f1a4efff779550cb5e85946940fd7b --- /dev/null +++ b/colossalai/nn/layer/moe/__init__.py @@ -0,0 +1,10 @@ +from .checkpoint import load_moe_model, save_moe_model +from .experts import Experts, FFNExperts, TPExperts +from .layers import MoeLayer, MoeModule +from .routers import MoeRouter, Top1Router, Top2Router +from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts + +__all__ = [ + 'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator', + 'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model' +] diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..37f31c16709b94bb6d2cebbabd850c03a85895eb --- /dev/null +++ b/colossalai/nn/layer/moe/_operation.py @@ -0,0 +1,175 @@ +from typing import Any, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup + +COL_MOE_KERNEL_FLAG = False + +try: + from colossalai._C import moe +except: + moe = None + + +def build_moe_if_not_prebuilt(): + # load moe kernel during runtime if not pre-built + global moe + if moe is None: + from colossalai.kernel.op_builder import MOEBuilder + moe = MOEBuilder().load() + + +class AllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + + global moe + + if moe is None: + from colossalai.kernel.op_builder import MOEBuilder + moe = MOEBuilder().load() + + if ctx is not None: + ctx.comm_grp = group + + comm_size = dist.get_world_size(group) + if comm_size == 1: + return inputs.unsqueeze(0) + + buffer_shape = (comm_size,) + inputs.shape + outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device) + buffer_list = list(torch.chunk(outputs, comm_size, dim=0)) + dist.all_gather(buffer_list, inputs, group=group) + return outputs + + @staticmethod + def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: + return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None + + +class ReduceScatter(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + if ctx is not None: + ctx.comm_grp = group + + comm_size = dist.get_world_size(group) + if comm_size == 1: + return inputs.squeeze(0) + + if not inputs.is_contiguous(): + inputs = inputs.contiguous() + + output_shape = inputs.shape[1:] + outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device) + buffer_list = list(torch.chunk(inputs, comm_size, dim=0)) + dist.reduce_scatter(outputs, buffer_list, group=group) + return outputs + + @staticmethod + def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: + return AllGather.forward(None, grad_outputs, ctx.comm_grp), None + + +class AllToAll(torch.autograd.Function): + """Dispatches input tensor [e, c, h] to all experts by all_to_all_single + operation in torch.distributed. + """ + + @staticmethod + def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + if ctx is not None: + ctx.comm_grp = group + if not inputs.is_contiguous(): + inputs = inputs.contiguous() + if dist.get_world_size(group) == 1: + return inputs + output = torch.empty_like(inputs) + dist.all_to_all_single(output, inputs, group=group) + return output + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: + return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None + + +class MoeDispatch(torch.autograd.Function): + + @staticmethod + def forward(ctx, tokens, mask, dest_idx, ec): + s = tokens.size(0) + h = tokens.size(1) + + # load moe kernel during runtime if not pre-built + build_moe_if_not_prebuilt() + + expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx) + + ctx.save_for_backward(mask, dest_idx) + ctx.s = s + ctx.h = h + ctx.ec = ec + + return expert_input + + @staticmethod + def backward(ctx, output_grad): + mask, dest_idx = ctx.saved_tensors + d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx) + return d_tokens, None, None, None + + +class MoeCombine(torch.autograd.Function): + + @staticmethod + def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): + assert logits.dtype == torch.float32 + + s = logits.size(0) + e = logits.size(1) + c = ec // e + h = expert_tokens.size(-1) + + # load moe kernel during runtime if not pre-built + build_moe_if_not_prebuilt() + + fp16_flag = (expert_tokens.dtype == torch.float16) + cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens + ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) + output = ctokens.to(torch.float16) if fp16_flag else ctokens + + ctx.save_for_backward(expert_tokens, logits, mask, dest_idx) + ctx.s = s + ctx.e = e + ctx.c = c + ctx.h = h + ctx.fp16_flag = fp16_flag + + return output + + @staticmethod + def backward(ctx, tokens_grad): + expert_tokens, logits, mask, dest_idx = ctx.saved_tensors + + cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \ + else tokens_grad + cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens + d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx) + d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert + + return d_expert, d_logits, None, None, None + + +def moe_cumsum(inputs: Tensor): + dim0 = inputs.size(0) + flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0) + if flag and COL_MOE_KERNEL_FLAG: + # load moe kernel during runtime if not pre-built + build_moe_if_not_prebuilt() + return moe.cumsum_sub_one(inputs) + else: + return torch.cumsum(inputs, dim=0) - 1 diff --git a/colossalai/nn/layer/moe/checkpoint.py b/colossalai/nn/layer/moe/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..efda1f22252d97db97219c9cc5a0a6c62c182bcd --- /dev/null +++ b/colossalai/nn/layer/moe/checkpoint.py @@ -0,0 +1,40 @@ +import torch +import torch.distributed as dist +import torch.nn as nn + +from .experts import MoeExperts + + +def save_moe_model(model: nn.Module, save_path: str): + state_dict = model.state_dict() + if dist.get_rank() == 0: + torch.save(state_dict, save_path) + dist.barrier() + + +def load_moe_model(model: nn.Module, load_path: str): + state_dict = torch.load(load_path) + + for prefix, module in model.named_modules(): + if prefix.endswith('.moe_layer.experts'): + # this module should be an Experts instance + assert isinstance(module, MoeExperts) + + ep_rank = dist.get_rank(module.dist_info.ep_group) + num_local = module.num_local_experts + for i in range(num_local): + expert_id = ep_rank * num_local + i + for name, _ in module.experts[i].named_parameters(): + cur_key = f'{prefix}.experts.{i}.{name}' + param_key = f'{prefix}.experts.{expert_id}.{name}' + load_param = state_dict[param_key] + state_dict[cur_key] = load_param + + for name, _ in module.experts[0].named_parameters(): + pop_pre = f'{prefix}.experts.' + pop_suf = f'.{name}' + for i in range(num_local, module.num_total_experts): + pop_key = f'{pop_pre}{i}{pop_suf}' + state_dict.pop(pop_key) + + model.load_state_dict(state_dict) diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py new file mode 100644 index 0000000000000000000000000000000000000000..56b11f4d9e08b7556268ea427a3fdda92d52ef61 --- /dev/null +++ b/colossalai/nn/layer/moe/experts.py @@ -0,0 +1,203 @@ +import math +from copy import deepcopy +from typing import Type + +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.context import ParallelMode, seed +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.utils import get_current_device +from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator + + +class MoeExperts(nn.Module): + """Basic class for experts in MoE. It stores what kind of communication experts use + to exchange tokens, how many experts in a single GPU and parallel information such as + expert parallel size, data parallel size and their distributed communication groups. + """ + + def __init__(self, comm_name: str, num_experts: int): + super().__init__() + assert comm_name in {"all_to_all", "all_gather"}, \ + "This kind of communication has not been implemented yet.\n Please use Experts build function." + self.comm_name = comm_name + self.num_total_experts = num_experts + # Get the configuration of experts' deployment and parallel information from moe context + self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts) + + +@no_shard_zero_decrator(is_replicated=False) +class Experts(MoeExperts): + """A wrapper class to create experts. It will create E experts across the + moe model parallel group, where E is the number of experts. Every expert + is a instance of the class, 'expert' in initialization parameters. + + Args: + expert_cls (:class:`torch.nn.Module`): The class of all experts + num_experts (int): The number of experts + expert_args: Args used to initialize experts, the args could be found in corresponding expert class + """ + + def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args): + super().__init__("all_to_all", num_experts) + + # Use seed to make every expert different from others + with seed(ParallelMode.TENSOR): + self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)]) + + # Attach parallel information for all parameters in Experts + for exp in self.experts: + for param in exp.parameters(): + param.__setattr__('moe_info', self.dist_info) + + def forward(self, inputs: torch.Tensor): + # Split inputs for each expert + expert_input = torch.chunk(inputs, self.num_local_experts, dim=1) + expert_output = [] + + # Get outputs from each expert + for i in range(self.num_local_experts): + expert_output.append(self.experts[i](expert_input[i])) + + # Concatenate all outputs together + output = torch.cat(expert_output, dim=1).contiguous() + return output + + def state_dict(self, destination=None, prefix='', keep_vars=False): + assert keep_vars == False, "Only support keep_vars=False now" + dp_rank = dist.get_rank(self.dist_info.dp_group) + ep_rank = dist.get_rank(self.dist_info.ep_group) + submodule_dict = dict() + example_submodule = None + for name, subm in self.experts.named_modules(): + if subm is self.experts: + continue + module_number = self.num_local_experts * ep_rank + int(name) + submodule_dict[module_number] = subm + example_submodule = subm + + if dp_rank == 0: + local_prefix = prefix + 'experts.' + buffer_module = deepcopy(example_submodule) + for i in range(self.num_total_experts): + source_rank = i // self.num_local_experts + current_prefix = local_prefix + str(i) + '.' + comm_module = submodule_dict.get(i, buffer_module) + for name, param in comm_module.named_parameters(): + dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group) + if ep_rank == 0: + destination[current_prefix + name] = param.data.cpu() + + dist.barrier() + + +class FFNExperts(MoeExperts): + """Use torch.bmm to speed up for multiple experts. + """ + + def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + super().__init__("all_to_all", num_experts) + + self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device())) + self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device())) + + self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device())) + self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device())) + + s1 = math.sqrt(0.1 / d_model) + s2 = math.sqrt(0.1 / d_ff) + + with seed(ParallelMode.TENSOR): + nn.init.trunc_normal_(self.w1, std=s1) + nn.init.trunc_normal_(self.b1, std=s1) + nn.init.trunc_normal_(self.w2, std=s2) + nn.init.trunc_normal_(self.b2, std=s2) + + self.act = nn.GELU() if activation is None else activation + self.drop = nn.Dropout(p=drop_rate) + + for param in self.parameters(): + param.__setattr__('moe_info', self.dist_info) + + def forward(self, inputs): # inputs [g, el, c, h] + + el = inputs.size(1) + h = inputs.size(-1) + + inputs = inputs.transpose(0, 1) + inshape = inputs.shape + inputs = inputs.reshape(el, -1, h) + + out_ff = torch.baddbmm(self.b1, inputs, self.w1) + out_act = self.act(out_ff) + with seed(ParallelMode.TENSOR): + out_inter = self.drop(out_act) + + out_model = torch.baddbmm(self.b2, out_inter, self.w2) + with seed(ParallelMode.TENSOR): + outputs = self.drop(out_model) # outputs [el, gc, h] + + outputs = outputs.reshape(inshape) + outputs = outputs.transpose(0, 1).contiguous() + return outputs + + +class TPExperts(MoeExperts): + """Use tensor parallelism to split each expert evenly, which can deploy experts in + case that the number of experts can't be divide by maximum expert parallel size or + maximum expert parallel size can't be divide by the number of experts. + """ + + def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + super().__init__("all_gather", MOE_CONTEXT.max_ep_size) + + assert d_ff % MOE_CONTEXT.max_ep_size == 0, \ + "d_ff should be divide by maximum expert parallel size" + + p_ff = d_ff // MOE_CONTEXT.max_ep_size + + self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device())) + self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device())) + + self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device())) + self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device())) + + s1 = math.sqrt(0.1 / d_model) + s2 = math.sqrt(0.1 / d_ff) + + with seed(ParallelMode.TENSOR): + nn.init.trunc_normal_(self.w1, std=s1) + nn.init.trunc_normal_(self.b1, std=s1) + nn.init.trunc_normal_(self.w2, std=s2) + + nn.init.trunc_normal_(self.b2, std=s2) + + self.act = nn.GELU() if activation is None else activation + self.drop = nn.Dropout(p=drop_rate) + + self.w1.__setattr__('moe_info', self.dist_info) + self.w2.__setattr__('moe_info', self.dist_info) + self.b1.__setattr__('moe_info', self.dist_info) + + def forward(self, inputs): # inputs [g, e, c, h] + + e = inputs.size(1) + h = inputs.size(-1) + + inputs = inputs.transpose(0, 1) + inshape = inputs.shape + inputs = inputs.reshape(e, -1, h) + + out_ff = torch.baddbmm(self.b1, inputs, self.w1) + out_act = self.act(out_ff) + with seed(ParallelMode.TENSOR): + out_inter = self.drop(out_act) + + out_model = torch.baddbmm(self.b2, out_inter, self.w2) + outputs = self.drop(out_model) # outputs [e, gc, h] + + outputs = outputs.reshape(inshape) + outputs = outputs.transpose(0, 1).contiguous() + return outputs # outputs [g, e, c, h] diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..03f55d91f3a861fd552c6df63dbc6b171d43e53e --- /dev/null +++ b/colossalai/nn/layer/moe/layers.py @@ -0,0 +1,210 @@ +import math +from typing import Optional, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.nn.layer.moe._operation import ( + COL_MOE_KERNEL_FLAG, + AllGather, + AllToAll, + MoeCombine, + MoeDispatch, + ReduceScatter, +) +from colossalai.nn.layer.moe.experts import Experts, MoeExperts +from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router +from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator +from colossalai.utils import get_current_device +from colossalai.zero.legacy.init_ctx import no_shard_zero_context, no_shard_zero_decrator + + +@no_shard_zero_decrator(is_replicated=True) +class MoeLayer(nn.Module): + """A MoE layer, that puts its input tensor to its gate and uses the output logits + to router all tokens, is mainly used to exchange all tokens for every expert across + the moe tensor group by all to all communication. Then it will get the output of all + experts and exchange the output. At last returns the output of the moe system. + + Args: + dim_model (int): Dimension of model. + num_experts (int): The number of experts. + router (MoeRouter): Instance of router used in routing. + experts (MoeExperts): Instance of experts generated by Expert. + """ + + def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts): + super().__init__() + self.d_model = dim_model + self.num_experts = num_experts + self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model)) + self.router: MoeRouter = router + self.experts: MoeExperts = experts + self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False + self.ep_group = experts.dist_info.ep_group + self.ep_size = experts.dist_info.ep_size + self.num_local_experts = experts.num_local_experts + + nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model)) + + def a2a_process(self, dispatch_data: torch.Tensor): + expert_input = AllToAll.apply(dispatch_data, self.ep_group) + input_shape = expert_input.shape + expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model) + expert_output = self.experts(expert_input) + expert_output = expert_output.reshape(input_shape) + expert_output = AllToAll.apply(expert_output, self.ep_group) + return expert_output + + def tp_process(self, dispatch_data: torch.Tensor): + expert_in = AllGather.apply(dispatch_data, self.ep_group) + expert_out = self.experts(expert_in) + expert_out = ReduceScatter.apply(expert_out, self.ep_group) + return expert_out + + def forward(self, inputs: torch.Tensor) -> Tuple: + # reshape the input tokens + tokens = inputs.reshape(-1, self.d_model) + + # the data type of the inputs in the gating should be fp32 + fp32_input = tokens.to(torch.float) + fp32_weight = self.gate_weight.to(torch.float) + gate_output = F.linear(fp32_input, fp32_weight) + + # the result from the router + route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) + + if self.use_kernel: + dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) + dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model) + else: + sec_mask_f = route_result_list[1].type_as(inputs) + dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) + + # dispatch_data [e, c, h] + if self.experts.comm_name == "all_to_all": + expert_output = self.a2a_process(dispatch_data) + elif self.experts.comm_name == "all_gather": + expert_output = self.tp_process(dispatch_data) + else: + raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " + "build function.") + # expert_output [e, c, h] + if self.use_kernel: + expert_output = expert_output.reshape(-1, self.d_model) + ans = MoeCombine.apply(expert_output, *route_result_list) + else: + combine_weights = route_result_list[0].type_as(inputs) + combine_weights = combine_weights.view(combine_weights.shape[0], -1) + expert_output = expert_output.view(-1, expert_output.shape[-1]) + ans = torch.matmul(combine_weights, expert_output) + + ans = ans.reshape(inputs.shape) + l_aux = self.router.pop_routing_loss() + return ans, l_aux + + +class MoeModule(nn.Module): + """A class for users to create MoE modules in their models. + + Args: + dim_model (int): Hidden dimension of training model + num_experts (int): The number experts + top_k (int, optional): The number of experts for dispatchment of each token + capacity_factor_train (float, optional): Capacity factor in routing during training + capacity_factor_eval (float, optional): Capacity factor in routing during evaluation + min_capacity (int, optional): The minimum number of the capacity of each expert + noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'. + 'Jitter' can be found in `Switch Transformer paper`_. + 'Gaussian' can be found in `ViT-MoE paper`_. + drop_tks (bool, optional): Whether drops tokens in evaluation + use_residual (bool, optional): Makes this MoE layer a Residual MoE. + More information can be found in `Microsoft paper`_. + residual_instance (nn.Module, optional): The instance of residual module in Residual MoE + expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer + expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given + expert_args (optional): The args of expert when no instance is given + + .. _Switch Transformer paper: + https://arxiv.org/abs/2101.03961 + .. _ViT-MoE paper: + https://arxiv.org/abs/2106.05974 + .. _Microsoft paper: + https://arxiv.org/abs/2201.05596 + """ + + def __init__(self, + dim_model: int, + num_experts: int, + top_k: int = 1, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_policy: Optional[str] = None, + drop_tks: bool = True, + use_residual: bool = False, + residual_instance: Optional[nn.Module] = None, + expert_instance: Optional[MoeExperts] = None, + expert_cls: Optional[Type[nn.Module]] = None, + **expert_args): + super().__init__() + + noisy_func = None + if noisy_policy is not None: + if noisy_policy == 'Jitter': + noisy_func = UniformNoiseGenerator() + elif noisy_policy == 'Gaussian': + noisy_func = NormalNoiseGenerator(num_experts) + else: + raise NotImplementedError("Unsupported input noisy policy") + + if top_k == 1: + moe_router_cls = Top1Router + elif top_k == 2: + moe_router_cls = Top2Router + else: + raise NotImplementedError("top_k > 2 is not supported yet") + + self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + self.use_residual = use_residual + if use_residual: + if residual_instance is not None: + self.residual_module = residual_instance + else: + assert expert_cls is not None, \ + "Expert class can't be None when residual instance is not given" + self.residual_module = expert_cls(**expert_args) + + with no_shard_zero_context(): + self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device()) + + if expert_instance is not None: + my_experts = expert_instance + else: + assert expert_cls is not None, \ + "Expert class can't be None when experts instance is not given" + my_experts = Experts(expert_cls, num_experts, **expert_args) + + self.moe_layer = MoeLayer(dim_model=dim_model, + num_experts=num_experts, + router=self.moe_router, + experts=my_experts) + + def forward(self, inputs: torch.Tensor): + moe_output, l_aux = self.moe_layer(inputs) + + if self.use_residual: + residual_output = self.residual_module(inputs) + combine_coef = self.residual_combine(inputs) + combine_coef = F.softmax(combine_coef, dim=-1) + output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:] + else: + output = moe_output + + return output, l_aux diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py new file mode 100644 index 0000000000000000000000000000000000000000..c5b8390bf0472302d10ef96080a23007d02c9eb7 --- /dev/null +++ b/colossalai/nn/layer/moe/routers.py @@ -0,0 +1,226 @@ +import math +from abc import ABC + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from colossalai.utils import get_current_device +from colossalai.context import MOE_CONTEXT +from colossalai.nn.layer.moe._operation import moe_cumsum +from typing import Callable, Optional +from torch.distributed import ProcessGroup + + +class MoeRouter(nn.Module, ABC): + """Base class for all MoE routers. + Args: + k_value (int): The value of top_k. + capacity_factor_train (float): Capacity factor in routing of training. + capacity_factor_eval (float): Capacity factor in routing of evaluation. + min_capacity (int): The minimum number of the capacity of each expert. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation + """ + + def __init__(self, + k_value: int, + capacity_factor_train: float, + capacity_factor_eval: float, + min_capacity: int, + noisy_func: Callable = None, + drop_tks: bool = True): + super().__init__() + self.k_value = k_value + self.capacity_factor_train = capacity_factor_train + self.capacity_factor_eval = capacity_factor_eval + self.min_capacity = min_capacity + self.noisy_func = noisy_func + self.drop_tks = drop_tks + self._routing_loss = None + + def get_capacity(self, logits_shape): + capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval + capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) + capacity += capacity % 2 + capacity = max(capacity, self.min_capacity) + assert capacity > 0 + return capacity + + def set_routing_loss(self, aux_loss: torch.Tensor) -> None: + assert self._routing_loss is None + self._routing_loss = aux_loss + + def pop_routing_loss(self) -> torch.Tensor: + assert self._routing_loss is not None + reservation = self._routing_loss + self._routing_loss = None + return reservation + + +class Top1Router(MoeRouter): + """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] + for routing usage. More detailed function can be found in the paper about Switch Transformer + of Google. + Args: + capacity_factor_train (float, optional): Capacity factor in routing of training. + capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. + min_capacity (int, optional): The minimum number of the capacity of each expert. + select_policy (str, optional): The policy about tokens selection. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation + """ + + def __init__(self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + select_policy: str = "first", + noisy_func: Callable = None, + drop_tks: bool = True): + super().__init__(k_value=1, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + self.select_policy = select_policy + assert select_policy in {"first", "random"} + if select_policy == "random": + self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), + high=torch.tensor(1.0, + device=get_current_device())).rsample + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + logits = F.softmax(inputs, dim=-1) + num_experts = logits.size(-1) + capacity = self.get_capacity(logits.shape) + + top1_idx = torch.argmax(inputs, dim=-1) + mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + + # caculate the auxiliary loss + me = torch.mean(logits, dim=0) + ce = torch.mean(mask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) + self.set_routing_loss(l_aux) + + if not self.training and not self.drop_tks: + max_num = torch.max(torch.sum(mask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + if self.select_policy == "random": + rand_mask = mask * self.uniform(mask.shape) + _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) + mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) + ranks = moe_cumsum(mask) + elif self.select_policy == "first": + ranks = moe_cumsum(mask) + mask = mask * torch.lt(ranks, capacity) + else: + raise NotImplementedError("Not support such select policy yet.") + + ranks = torch.sum(mask * ranks, dim=-1) + + if use_kernel: + mask = torch.sum(mask, dim=-1) + mask = torch.stack([mask], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) + return logits, mask, dest_idx, num_experts * capacity + else: + ranks = F.one_hot(ranks, num_classes=capacity) + weight = mask * logits.type_as(inputs) + combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) + sec_mask = combine_weights.bool() + return combine_weights, sec_mask + + +class Top2Router(MoeRouter): + """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] + for routing usage. More detailed function can be found in the paper about ViT-MoE. + Args: + capacity_factor_train (float, optional): Capacity factor in routing of training. + capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. + min_capacity (int, optional): The minimum number of the capacity of each expert + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation. + """ + + def __init__(self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Callable = None, + drop_tks: bool = True): + super().__init__(k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + # inputs: [s, h] + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + logits = F.softmax(inputs, dim=-1) # logits: [s, e] + num_experts = logits.size(-1) + capacity = self.get_capacity(logits.shape) + + top1_idx = torch.argmax(logits, dim=-1) + mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) + top2_idx = torch.argmax(logits_except1, dim=-1) + mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) + + cmask = (mask1 + mask2) # loss: [s, e] + + # caculate the auxiliary loss + me = torch.mean(logits, dim=0) + ce = torch.mean(cmask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 + self.set_routing_loss(l_aux) + + if not self.training and not self.drop_tks: + max_num = torch.max(torch.sum(cmask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + rank1 = moe_cumsum(mask1) # rank1: [s, e] + rank2 = moe_cumsum(mask2) + rank2 += torch.sum(mask1, dim=-2, keepdim=True) + + mask1 *= torch.lt(rank1, capacity) + mask2 *= torch.lt(rank2, capacity) + + rank1 = torch.sum(mask1 * rank1, dim=-1) + rank2 = torch.sum(mask2 * rank2, dim=-1) + + if use_kernel: + mask1 = torch.sum(mask1, dim=-1) + mask2 = torch.sum(mask2, dim=-1) + + mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) + + return logits, mask, dest_idx, num_experts * capacity + else: + weight1 = mask1 * logits.type_as(inputs) + weight2 = mask2 * logits.type_as(inputs) + rank1_sc = F.one_hot(rank1, num_classes=capacity) + rank2_sc = F.one_hot(rank2, num_classes=capacity) + + cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) + cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) + cb_weight = cb_weight1 + cb_weight2 + sec_mask = cb_weight.bool() + + return cb_weight, sec_mask diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4ca8bd7033868706cde62eb45a98f4d88ffa2a67 --- /dev/null +++ b/colossalai/nn/layer/moe/utils.py @@ -0,0 +1,68 @@ +import torch +import torch.nn.functional as F +from colossalai.utils import get_current_device +from colossalai.context.moe_context import MOE_CONTEXT +from .experts import FFNExperts, TPExperts + + +class ForceFP32Parameter(torch.nn.Parameter): + + def half(self, memory_format=None): + return self.data.clone() + + +class NormalNoiseGenerator: + """Generates a random noisy mask for logits tensor. + + All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where + `E = the number of experts`. + + Args: + num_experts (int): The number of experts. + """ + + def __init__(self, num_experts: int): + self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()), + scale=torch.tensor(1.0 / num_experts**2, + device=get_current_device())).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.normal(inputs.shape) + return inputs + noisy + + +class UniformNoiseGenerator: + """Generates a random noisy mask for logits tensor. + copied from mesh tensorflow: + Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`. + Makes models more resilient to rounding errors introduced by bfloat16. + This seems particularly important for logits. + + Args: + eps (float, optional): Epsilon in generator, defaults 1e-2. + """ + + def __init__(self, eps: float = 1e-2): + self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()), + high=torch.tensor(1.0 + eps, + device=get_current_device())).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.uniform(inputs.shape) + return inputs * noisy + + +def autocast_softmax(logit: torch.Tensor, dim: int): + if logit.dtype != torch.float32: + logit = logit.float() + return F.softmax(logit, dim=dim) + + +def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + mep_size = MOE_CONTEXT.max_ep_size + if num_experts % mep_size == 0 or mep_size % num_experts == 0: + return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate) + elif d_ff % mep_size == 0: + return TPExperts(num_experts, d_model, d_ff, activation, drop_rate) + else: + raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") diff --git a/colossalai/nn/layer/parallel_1d/__init__.py b/colossalai/nn/layer/parallel_1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2353851df665246251bb7ef0d884dd5f961b7aac --- /dev/null +++ b/colossalai/nn/layer/parallel_1d/__init__.py @@ -0,0 +1,7 @@ +from .layers import (Classifier1D, Dropout1D, Embedding1D, LayerNorm1D, Linear1D, Linear1D_Col, Linear1D_Row, + PatchEmbedding1D, VocabParallelClassifier1D, VocabParallelEmbedding1D) + +__all__ = [ + 'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D', + 'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D' +] diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..3943345582758edaa1298ac4d927eb132887e072 --- /dev/null +++ b/colossalai/nn/layer/parallel_1d/_operation.py @@ -0,0 +1,96 @@ +import torch +import torch.distributed as dist +from colossalai.core import global_context as gpc + +try: + import fused_mix_prec_layer_norm_cuda +except: + fused_mix_prec_layer_norm_cuda = None + + +class FusedLayerNormAffineFunction1D(torch.autograd.Function): + r"""Layernorm + + Args: + input: input matrix. + weight: weight matrix. + bias: bias matrix. + normalized_shape: input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability + """ + + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, + bias_, ctx.eps) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias \ + = fused_mix_prec_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, + input_, ctx.normalized_shape, + weight_, bias_, ctx.eps) + + return grad_input, grad_weight, grad_bias, None, None + + +class LinearWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. + """ + + @staticmethod + def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.parallel_mode = parallel_mode + ctx.async_grad_allreduce = async_grad_allreduce + + output = torch.matmul(input_, weight.t()) + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + total_input = input + grad_input = grad_output.matmul(weight) + + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) + total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce): + return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce) diff --git a/colossalai/nn/layer/parallel_1d/_utils.py b/colossalai/nn/layer/parallel_1d/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1212d595635d7c305132edb8c74e01ab165a903b --- /dev/null +++ b/colossalai/nn/layer/parallel_1d/_utils.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +import torch.distributed as dist +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env + +from ..utils import divide + + +def set_parallel_input(input_parallel: bool): + env.parallel_input_1d = input_parallel + + +def get_parallel_input(): + return env.parallel_input_1d + + +def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + +def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank) + + +def _reduce(input_, parallel_mode): + # skip if only one rank involved + if gpc.get_world_size(parallel_mode) == 1: + return input_ + group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) + dist.all_reduce(input_, group=group) + + return input_ + + +def _split(input_, parallel_mode, dim=-1): + # skip if only one rank involved + world_size = gpc.get_world_size(parallel_mode) + if world_size == 1: + return input_ + + # Split along last dimension. + dim_size = input_.size(dim) + assert dim_size % world_size == 0, \ + f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ + f'cannot split tensor evenly' + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + rank = gpc.get_local_rank(parallel_mode) + output = tensor_list[rank].contiguous() + + return output + + +def _gather(input_, parallel_mode, dim=-1): + # skip if only one rank involved + world_size = gpc.get_world_size(parallel_mode) + if world_size == 1: + return input_ + + # all gather + rank = gpc.get_local_rank(parallel_mode) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) + torch.distributed.all_gather(tensor_list, input_, group=group) + + # concat + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +class _ReduceGrad(torch.autograd.Function): + """ + Pass the input to the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_, parallel_mode): + ctx.mode = parallel_mode + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output, ctx.mode), None + + +class _ReduceInput(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def symbolic(graph, input_): + return _reduce(input_) + + @staticmethod + def forward(ctx, input_, parallel_mode): + return _reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _split(input_) + + @staticmethod + def forward(ctx, input_, parallel_mode, dim): + ctx.mode = parallel_mode + ctx.dim = dim + return _split(input_, parallel_mode, dim) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output, ctx.mode, ctx.dim), None, None + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _gather(input_) + + @staticmethod + def forward(ctx, input_, parallel_mode, dim): + ctx.mode = parallel_mode + ctx.dim = dim + return _gather(input_, parallel_mode, dim) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.mode, ctx.dim), None, None + + +def reduce_grad(input_, parallel_mode): + return _ReduceGrad.apply(input_, parallel_mode) + + +def reduce_input(input_, parallel_mode): + return _ReduceInput.apply(input_, parallel_mode) + + +def split_forward_gather_backward(input_, parallel_mode, dim): + return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim) + + +def gather_forward_split_backward(input_, parallel_mode, dim): + return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim) diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..406173a18c6010de0b2004f641ec79df3de32dd3 --- /dev/null +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -0,0 +1,1040 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from collections import OrderedDict +from typing import Callable, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn.parameter import Parameter + +from colossalai.communication import broadcast +from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.kernel import LayerNorm +from colossalai.nn import init as init +from colossalai.registry import LAYERS +from colossalai.utils.checkpointing import ( + broadcast_state_dict, + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) +from colossalai.utils.cuda import get_current_device + +from ..base_layer import ParallelLayer +from ..colossalai_layer._utils import ColossalaiModule +from ..utils import divide, set_tensor_parallel_attribute_by_partition +from ..vanilla import VanillaLayerNorm, VanillaPatchEmbedding +from ._operation import linear_with_async_comm +from ._utils import ( + gather_forward_split_backward, + get_parallel_input, + reduce_grad, + reduce_input, + set_parallel_input, + split_forward_gather_backward, +) + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +@LAYERS.register_module +class Linear1D(ColossalaiModule): + r"""Linear layer for 1D parallelism. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + gather_output (bool, optional): Whether to call all-gather on output, defaults to False. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + parallel_input = get_parallel_input() + if not parallel_input and not gather_output: + layer = Linear1D_Col(in_features, + out_features, + bias=bias, + dtype=dtype, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer) + else: + layer = Linear1D_Row(in_features, + out_features, + bias=bias, + dtype=dtype, + parallel_input=parallel_input, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer) + super().__init__(layer) + + +@LAYERS.register_module +class LayerNorm1D(ColossalaiModule): + r""" + Layer Normalization for colossalai + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + _fast_ln_supported_sizes = [ + 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, + 24576, 25600, 30720, 32768, 40960, 49152, 65536 + ] + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): + if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: + norm = Fast_LN(normalized_shape, eps=eps).to(dtype) + else: + norm = None + try: + from apex.normalization import FusedLayerNorm + norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) + except ImportError: + norm = LayerNorm(normalized_shape, eps=eps).to(dtype) + super().__init__(norm) + + def _load_from_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) + super()._load_from_state_dict(local_state, prefix, *args) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + super()._save_to_state_dict(destination, prefix, keep_vars) + + +@LAYERS.register_module +class Classifier1D(ParallelLayer): + r"""RowLinear with given weight. Classifier of 1D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.parallel_input = get_parallel_input() + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = False + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ + 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) + input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) + + output_parallel = F.linear(input_, self.weight) + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + if self.bias is not None: + output = output + self.bias + return output + + +@LAYERS.register_module +class VocabParallelClassifier1D(ParallelLayer): + r"""ColLinear with given weight. Classifier of 1D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.gather_output = gather_output + self.parallel_input = get_parallel_input() + + # Divide the weight matrix along the last dimension. + self.num_classes_per_partition = divide(num_classes, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter(torch.empty(self.num_classes_per_partition, self.in_features, **factory_kwargs)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.empty(self.num_classes_per_partition, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = True + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + # Matrix multiply. + output_parallel = F.linear(input_parallel, self.weight, self.bias) + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel + return output + + +@LAYERS.register_module +class Linear1D_Col(ParallelLayer): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + if bias: + self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + is_parallel_output = not self.gather_output + set_parallel_input(is_parallel_output) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + # output_parallel = F.linear(input_parallel, self.weight, bias) + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True) + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +@LAYERS.register_module +class Linear1D_Row(ParallelLayer): + r""" Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) + input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=gpc.get_group(ParallelMode.PARALLEL_1D), + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias + + +@LAYERS.register_module +class Embedding1D(ParallelLayer): + r"""Embedding for 1D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size) + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1}, + partition_states={weight_key: True}) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + + output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + + return output + + +@LAYERS.register_module +class VocabParallelEmbedding1D(ParallelLayer): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and \ + self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0}, + partition_states={weight_key: True}) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + return output + + +@LAYERS.register_module +class Dropout1D(ParallelLayer): + """Dropout layer of 1D parallelism. + + Args: + p (float, optional): probability of an element to be zeroed, defaults 0.5. + inplace (bool, optional): whether to do dropout in-place, default to be False. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False): + super().__init__() + self.parallel_input = get_parallel_input() + self.p = p + self.inplace = inplace + + def forward(self, input_: Tensor) -> Tensor: + if self.parallel_input: + with seed(ParallelMode.TENSOR): + output = F.dropout(input_, self.p, self.training, self.inplace) + else: + output = F.dropout(input_, self.p, self.training, self.inplace) + return output + + +@LAYERS.register_module +class PatchEmbedding1D(ColossalaiModule): + """ + 2D Image to Patch Embedding + + :param img_size: image size + :type img_size: int + :param patch_size: patch size + :type patch_size: int + :param in_chans: number of channels of input image + :type in_chans: int + :param embed_size: size of embedding + :type embed_size: int + :param dtype: The dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param flatten: whether to flatten output tensor, defaults to True + :type flatten: bool, optional + :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer + :type weight_initializer: typing.Callable, optional + :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer + :type bias_initializer: typing.Callable, optional + :param position_embed_initializer: The initializer of position embedding, defaults to zero + :type position_embed_initializer: typing.Callable, optional + """ + + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + dtype: torch.dtype = None, + flatten: bool = True, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_()): + embed = VanillaPatchEmbedding(img_size, + patch_size, + in_chans, + embed_size, + dtype=dtype, + flatten=flatten, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + position_embed_initializer=position_embed_initializer) + super().__init__(embed) + + def _load_from_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed'] + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + for key in param_keys: + param = state_dict.pop(key, None) + if param is not None: + local_state[key] = param + + local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) + super()._load_from_state_dict(local_state, prefix, *args) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/colossalai/nn/layer/parallel_2d/__init__.py b/colossalai/nn/layer/parallel_2d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5562d1a700361c23bd8238848b5141c05c9b25aa --- /dev/null +++ b/colossalai/nn/layer/parallel_2d/__init__.py @@ -0,0 +1,8 @@ +from ._operation import reduce_by_batch_2d, split_batch_2d +from .layers import (Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, VocabParallelClassifier2D, + VocabParallelEmbedding2D) + +__all__ = [ + 'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', + 'Embedding2D', 'VocabParallelEmbedding2D', 'VocabParallelClassifier2D' +] diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..306577dbd9333987bb181d70ef21fcff1d548b7c --- /dev/null +++ b/colossalai/nn/layer/parallel_2d/_operation.py @@ -0,0 +1,849 @@ +from typing import Any, Optional, Tuple + +import torch +import torch.distributed as dist +from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter) +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.global_variables import tensor_parallel_env as env + + +def matmul_2d( + a, + b, + summa_dim, + out_shape, + row_rank=None, + col_rank=None, + row_parallel_mode=ParallelMode.PARALLEL_2D_ROW, + col_parallel_mode=ParallelMode.PARALLEL_2D_COL, +): + r"""Matrix multiplication for 2D parallelism. + + Args: + a (:class:`torch.tensor`): matrix :math:`A`. + b (:class:`torch.tensor`): matrix :math:`B`. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.context.ParallelMode`, optional): + row parallel mode, defaults to ParallelMode.PARALLEL_2D_ROW. + col_parallel_mode (:class:`colossalai.context.ParallelMode`, optional): + column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL. + + Returns: + :class:`torch.tensor`: :math:`C = AB`. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + if row_rank is None: + row_rank = gpc.get_local_rank(col_parallel_mode) + if col_rank is None: + col_rank = gpc.get_local_rank(row_parallel_mode) + + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + tensor_parallel_size = summa_dim**2 + return Matmul_AB_2D(a, b, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, col_parallel_mode, + data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) + + +class _Classifier2D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + bias: Optional[Tensor], + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + A = A.clone().detach() + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + B_temp = all_gather(B, -1, col_parallel_mode) + if ctx: + ctx.save_for_backward(A, B_temp) + + C = torch.matmul(A, B_temp.transpose(0, 1)) + + C = all_reduce(C, row_parallel_mode) + + ctx.use_bias = bias is not None + if bias is not None: + C = C + bias + + out = C.reshape(out_shape) + + if ctx: + ctx.summa_dim = summa_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + + with torch.no_grad(): + A_grad = torch.matmul(output_grad, B) + A_grad = A_grad.reshape(ctx.A_shape) + B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A) + B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode) + B_grad = B_grad.reshape(ctx.B_shape) + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1))) + bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode) + else: + bias_grad = None + + return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None + + +def classifier_2d(A: Tensor, B: Tensor, bias: Optional[Tensor], summa_dim: int, out_shape: Tuple[int, ...], + row_rank: int, col_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: + r"""2D parallel classifier. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + bias (:class:`torch.tensor`, optional): matrix of bias. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Classifier2D.apply(A, B, bias, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, + col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, + tensor_parallel_size) + + +class Matmul_AB_2D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = AB`. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + # A: [b / q, s, h / q] -> [(b * s) / q, h / q] + # B: [h / q, s / q] + # C: [b / q, s, s / q] -> [(b * s) / q, s / q] + + assert A.shape[-1] == B.shape[-2], \ + 'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[0], B.shape[-1]) + C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + A_list = [torch.empty_like(A) for _ in range(2)] + B_list = [torch.empty_like(B) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opa = [None] * 2 + opb = [None] * 2 + + A_list[0].copy_(A) + B_list[0].copy_(B) + opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) + opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) + cur = 0 + + for i in range(summa_dim): + if i != summa_dim - 1: + A_list[1 - cur].copy_(A) + opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) + B_list[1 - cur].copy_(B) + opb[1 - cur] = dist.broadcast(B_list[1 - cur], src=src_b + summa_dim, group=col_group, async_op=True) + + if opa[cur] is not None: + opa[cur].wait() + if opb[cur] is not None: + opb[cur].wait() + + torch.addmm(C, A_list[cur], B_list[cur], out=C) + cur = 1 - cur + src_a += 1 + src_b += summa_dim + + out = C.reshape(out_shape) + + if ctx: + ctx.summa_dim = summa_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + with torch.no_grad(): + A_grad = Matmul_ABT_2D.apply(output_grad, B, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, + ctx.tensor_parallel_size) + B_grad = Matmul_ATB_2D.apply(A, output_grad, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, + ctx.tensor_parallel_size) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None + + +class Matmul_ABT_2D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = AB^T` + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + + assert A.shape[-1] == B.shape[-1], \ + 'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[0], B.shape[0]) + C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + B_list = [torch.empty_like(B) for _ in range(2)] + C_list = [torch.empty_like(C) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opb = [None] * 2 + opr = [None] * 2 + + B_list[0].copy_(B) + opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) + cur = 0 + + for i in range(summa_dim): + if i != summa_dim - 1: + B_list[1 - cur].copy_(B) + opb[1 - cur] = dist.broadcast(B_list[1 - cur], src=src_b + summa_dim, group=col_group, async_op=True) + + if opr[cur] is not None: + opr[cur].wait() + if i - 2 == col_rank: + C.copy_(C_list[cur]) + + if opb[cur] is not None: + opb[cur].wait() + + torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur]) + opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True) + cur = 1 - cur + src_b += summa_dim + src_c += 1 + + for op in opr: + op.wait() + + if summa_dim - 2 == col_rank: + C.copy_(C_list[cur]) + if summa_dim - 1 == col_rank: + C.copy_(C_list[1 - cur]) + out = C.reshape(out_shape) + + if ctx: + ctx.summa_dim = summa_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + + with torch.no_grad(): + A_grad = Matmul_AB_2D.apply(output_grad, B, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, + ctx.tensor_parallel_size) + B_grad = Matmul_ATB_2D.apply(output_grad, A, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, + ctx.tensor_parallel_size) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None + + +class Matmul_ATB_2D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = A^TB`. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + + assert A.shape[-2] == B.shape[-2], \ + 'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[-1], B.shape[-1]) + C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + A_list = [torch.empty_like(A) for _ in range(2)] + C_list = [torch.empty_like(C) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opa = [None] * 2 + opr = [None] * 2 + + A_list[0].copy_(A) + opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) + cur = 0 + + for i in range(summa_dim): + if i != summa_dim - 1: + A_list[1 - cur].copy_(A) + opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) + + if opr[cur] is not None: + opr[cur].wait() + if i - 2 == row_rank: + C.copy_(C_list[cur]) + + if opa[cur] is not None: + opa[cur].wait() + + torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur]) + opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True) + cur = 1 - cur + src_a += 1 + src_c += summa_dim + + for op in opr: + op.wait() + + if summa_dim - 2 == row_rank: + C.copy_(C_list[cur]) + if summa_dim - 1 == row_rank: + C.copy_(C_list[1 - cur]) + out = C.reshape(out_shape) + + if ctx: + ctx.summa_dim = summa_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + + with torch.no_grad(): + A_grad = Matmul_ABT_2D.apply(B, output_grad, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, + ctx.tensor_parallel_size) + B_grad = Matmul_AB_2D.apply(A, output_grad, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, + ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, + ctx.tensor_parallel_size) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None + + +class _Add_Bias_2D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + input_: Tensor, + bias: Tensor, + output_size_per_partition: int, + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + skip_bias_add: bool, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + bias_temp = all_gather(bias, -1, col_parallel_mode) + + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.bias = skip_bias_add + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + if skip_bias_add: + return bias_temp + else: + output = input_ + bias_temp + return output + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + col_parallel_mode = ctx.col_parallel_mode + + if ctx.bias: + grad = reduce_scatter(output_grad, -1, col_parallel_mode) + return None, grad, None, None, None, None, None, None, None, None, None, None + else: + reduce_dim = tuple(range(output_grad.ndim - 1)) + reduce = torch.sum(output_grad, dim=reduce_dim) + grad = reduce_scatter(reduce, -1, col_parallel_mode) + return output_grad, grad, None, None, None, None, None, None, None, None, None, None + + +def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, row_rank: int, col_rank: int, + row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, skip_bias_add: bool, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: + r"""Matrix add bias: :math:`C = A + b`. + + Args: + input_ (:class:`torch.tensor`): matrix :math:`A`. + bias (:class:`torch.tensor`): matrix :math:`B`. + output_size_per_partition (int): size of output per partition. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + skip_bias_add (bool): + If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Add_Bias_2D.apply(input_, bias, output_size_per_partition, row_rank, col_rank, row_parallel_mode, + col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank, + pipeline_parallel_size, tensor_parallel_size) + + +class _Layernorm_2D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode) -> Tensor: + input_ = input_ - E_x + # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) + ctx.normalized_shape = hidden_size + output = input_ * Var_x + ctx.save_for_backward(output, Var_x) + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + return output + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + row_parallel_mode = ctx.row_parallel_mode + col_parallel_mode = ctx.col_parallel_mode + x, Var_x = ctx.saved_tensors + # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x + output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True) + torch.distributed.all_reduce(output_grad_sum, group=gpc.get_group(row_parallel_mode)) + output_grad_sum /= ctx.normalized_shape + + output_grad_mul_x_sum = torch.sum(output_grad * x, dim=-1, keepdim=True) + torch.distributed.all_reduce(output_grad_mul_x_sum, group=gpc.get_group(row_parallel_mode)) + output_grad_mul_x_sum /= ctx.normalized_shape + + input_grad = output_grad.clone() + input_grad -= x * output_grad_mul_x_sum + input_grad -= output_grad_sum + input_grad *= Var_x + + return input_grad, None, None, None, None, None + + +def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode) -> Tensor: + r"""Layernorm. + + Args: + input_ (:class:`torch.tensor`): input matrix. + E_x (:class:`torch.tensor`): mean. + Var_x (:class:`torch.tensor`): variance. + hidden_size (int): hidden size. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Layernorm_2D.apply(input_, E_x, Var_x, hidden_size, row_parallel_mode, col_parallel_mode) + + +class _AllGatherTensor2D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, inputs: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + ctx.dim = dim + ctx.parallel_mode = parallel_mode + + outputs = all_gather(inputs, dim, parallel_mode) + return outputs + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode) + return grad.contiguous(), None, None + + +def all_gather_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""All gather the tensor of 2D parallelism. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to gather. + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _AllGatherTensor2D.apply(tensor, dim, parallel_mode) + + +def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor: + """Splits 2D tensor in specified dimension across cols. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + dim (int): Specified dimension in which to split. + + Returns: + :class:`torch.tensor`: The tensor has been split. + """ + dim_size = input_.size(dim) + world_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) + + if world_size <= 1: + return input_ + + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).' + + return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL), + dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous() + + +class _ReduceTensor2D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, parallel_mode): + return all_reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return output_grad, None + + +def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: + r"""All-reduce the input. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _ReduceTensor2D.apply(input_, parallel_mode) + + +class _ReduceScatterTensor2D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + return reduce_scatter(input_, dim, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return all_gather(output_grad, ctx.dim, ctx.parallel_mode), None, None + + +def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""Reduce-scatter the input. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to reduce. + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + dim_size = tensor.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).' + + return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode) + + +class _ReduceByBatch2D(torch.autograd.Function): + + @staticmethod + def symbolic(graph, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) + return output / reduce_size + return output + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) + ctx.reduce_mean = reduce_mean + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) + ctx.reduce_size = reduce_size + return output.clone() / reduce_size + return output.clone() + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + if ctx.reduce_mean: + return output_grad / ctx.reduce_size, None + else: + return output_grad, None + + +def reduce_by_batch_2d(input_, reduce_mean: bool = False) -> Tensor: + r"""All-reduce the input from the model parallel region. + + Args: + input_ (:class:`torch.tensor`): input matrix. + reduce_mean (bool, optional): + If set to ``True``, it will divide the output by column parallel size, default to False. + """ + return _ReduceByBatch2D.apply(input_, reduce_mean) diff --git a/colossalai/nn/layer/parallel_2d/_utils.py b/colossalai/nn/layer/parallel_2d/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..012fec41c80231165ceb92e57e2f449e61fdb8b2 --- /dev/null +++ b/colossalai/nn/layer/parallel_2d/_utils.py @@ -0,0 +1,20 @@ +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env + + +def get_summa_dim_from_env() -> int: + try: + summa_dim = env.summa_dim + assert summa_dim > 0, 'SUMMA_DIM must be larger than zero' + return summa_dim + + except KeyError as e: + raise EnvironmentError('SUMMA_DIM is not found in the current environment, ' + 'please make sure that you have used the correct process group initializer') + + +def assert_summa_initialization(): + assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL) and \ + gpc.is_initialized(ParallelMode.PARALLEL_2D_ROW), \ + 'Both TWO_DIMENSION_COL and TWO_DIMENSION_ROW must be initialized by the process group initializer' diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a4d2bbbc32f8f13815ea36a7c47268217446d5 --- /dev/null +++ b/colossalai/nn/layer/parallel_2d/layers.py @@ -0,0 +1,1201 @@ +import math +from collections import OrderedDict +from typing import Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from colossalai.communication import broadcast +from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn import init as init +from colossalai.registry import LAYERS +from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict +from colossalai.utils.cuda import get_current_device +from torch import Tensor +from torch.nn import Parameter + +from ..base_layer import ParallelLayer +from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple +from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d, + reduce_scatter_tensor_2d, split_batch_2d) +from ._utils import assert_summa_initialization, get_summa_dim_from_env + + +@LAYERS.register_module +class Linear2D(ParallelLayer): + r"""Linear layer for 2D parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.skip_bias_add = skip_bias_add + + # parallel settings + assert_summa_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + self.summa_dim = get_summa_dim_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(self.in_features, self.summa_dim) + self.hidden_size_per_partition = divide(self.out_features, self.summa_dim) + + # create weight, shape: [k/q, h/q] + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter( + torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)) + + # create bias, shape: [h/q] + if bias: + self.bias = Parameter(torch.empty(divide(self.out_features, self.summa_dim**2), **factory_kwargs)) + else: + self.register_parameter('bias', None) + + # initialize parameters + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight.transpose(0, 1) + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + local_state[weight_key] = local_state[weight_key].transpose(0, 1) + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + # input: [m/q, n/q, k/q] + # output: [m/q, n/q, h/q] + out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) + + output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, + self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) + + if self.bias is not None: + if self.skip_bias_add: + bias = add_bias_2d(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + return output, bias + else: + output = add_bias_2d(output, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + return output + else: + return output + + +@LAYERS.register_module +class LayerNorm2D(ParallelLayer): + r"""Layer Normalization for 2D parallelism. + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None): + super().__init__() + + # layer norm config + self.normalized_shape = normalized_shape + self.variance_epsilon = eps + + # parallel setting + assert_summa_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + self.summa_dim = get_summa_dim_from_env() + + # partitioning dimension + self.partitioned_partition = divide(normalized_shape, self.summa_dim**2) + + # create parameters + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + + self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) + if bias: + self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs)) + else: + self.bias = None + + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + with torch.no_grad(): + E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] + torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + E_x /= self.normalized_shape + + # Var_x in the block below is the sum of input^2 + Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] + torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + Var_x /= self.normalized_shape + + Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] + # this time 1/sqrt(Var_x + epsilon) + Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) + + output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL) + scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank, + self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) + if self.bias is not None: + bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + output = torch.addcmul(bias, scale, output) + else: + output = torch.mul(scale, output) + return output + + +@LAYERS.register_module +class PatchEmbedding2D(ParallelLayer): + r"""2D Image to Patch Embedding. + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_()): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + assert_summa_initialization() + self.summa_dim = get_summa_dim_from_env() + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.embed_size = embed_size + self.embed_size_per_partition = embed_size // (self.summa_dim**2) + + with seed(ParallelMode.TENSOR): + self.weight = Parameter( + torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size), + device=get_current_device(), + dtype=dtype)) + self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) + + self.cls_token = Parameter( + torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)) + self.pos_embed = Parameter( + torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition), + device=get_current_device(), + dtype=dtype)) + + self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) + self._set_tensor_parallel_attribute() + + def _set_tensor_parallel_attribute(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) + set_tensor_parallel_attribute_by_partition(self.cls_token, self.summa_dim**2) + set_tensor_parallel_attribute_by_partition(self.pos_embed, self.summa_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): + with seed(ParallelMode.TENSOR): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + fan_out = self.embed_size + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + bias_initializer(self.bias, fan_in=fan_in) + position_embed_initializer(self.pos_embed) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + cls_token_key = prefix + 'cls_token' + pos_embed_key = prefix + 'pos_embed' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + # cls token + cls_token = state_dict.pop(cls_token_key, None) + if cls_token is not None: + local_state[cls_token_key] = cls_token + # pos embed + pos_embed = state_dict.pop(pos_embed_key, None) + if pos_embed is not None: + local_state[pos_embed_key] = pos_embed + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + cls_token_key = prefix + 'cls_token' + pos_embed_key = prefix + 'pos_embed' + local_state = OrderedDict({ + weight_key: self.weight, + bias_key: self.bias, + cls_token_key: self.cls_token, + pos_embed_key: self.pos_embed + }) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_2d(input_) + + B, C, H, W = input_.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + + weight = all_gather_tensor_2d(self.weight, 0, ParallelMode.PARALLEL_2D_COL) + bias = all_gather_tensor_2d(self.bias, 0, ParallelMode.PARALLEL_2D_COL) + + output = F.conv2d(input_, weight, bias, stride=self.patch_size) + if self.flatten: + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + + cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL) + pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL) + cls_token = cls_token.expand(output.shape[0], -1, -1) + output = torch.cat((cls_token, output), dim=1) + output = output + pos_embed + + return output + + +@LAYERS.register_module +class Embedding2D(ParallelLayer): + r"""Embedding for 2D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + assert_summa_initialization() + self.summa_dim = get_summa_dim_from_env() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = divide(embedding_dim, self.summa_dim**2) + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_2d(input_) + + weight = all_gather_tensor_2d(self.weight, -1, ParallelMode.PARALLEL_2D_COL) + output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + return output + + +@LAYERS.register_module +class VocabParallelEmbedding2D(ParallelLayer): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + assert_summa_initialization() + self.summa_dim = get_summa_dim_from_env() + self.num_embeddings_per_partition = divide(self.num_embeddings, self.summa_dim) + self.embed_dim_per_partition = divide(self.embed_dim, self.summa_dim) + tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and \ + self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + output_parallel[input_mask, :] = 0. + output = reduce_scatter_tensor_2d(output_parallel, 0, ParallelMode.PARALLEL_2D_COL) + return output + + +@LAYERS.register_module +class Classifier2D(ParallelLayer): + r"""Classifier for 2D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + assert_summa_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + self.summa_dim = get_summa_dim_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(self.in_features, self.summa_dim**2) + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.num_classes + col_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)[0] + row_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_ROW)[0] + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL) + broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + out_shape = input_.shape[:-1] + (self.num_classes,) + + return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, + self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) + + +@LAYERS.register_module +class VocabParallelClassifier2D(ParallelLayer): + r"""Vocab parallel classifier layer for 2D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + self.in_features = in_features + self.num_classes = num_classes + + # parallel setting + assert_summa_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + self.summa_dim = get_summa_dim_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(in_features, self.summa_dim) + self.output_size_per_partition = divide(num_classes, self.summa_dim) + + # create weight, shape: [k/q, h/q] + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.output_size_per_partition, self.input_size_per_partition, **factory_kwargs)) + self.has_weight = True + # create bias, shape: [h/q] + if bias: + self.bias = Parameter(torch.empty(divide(self.num_classes, self.summa_dim**2), **factory_kwargs)) + else: + self.bias = None + + # initialize parameters + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + local_state[weight_key] = local_state[weight_key].transpose(0, 1) + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + # input: [m/q, n/q, k/q] + # output: [m/q, n/q, h/q] + out_shape = x.shape[:-1] + (self.output_size_per_partition,) + + output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + + if self.bias is not None: + output = add_bias_2d(output, self.bias, self.output_size_per_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + return output diff --git a/colossalai/nn/layer/parallel_2p5d/__init__.py b/colossalai/nn/layer/parallel_2p5d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bec3b1c4b0b87e8db497b627207ca6f30b1fff49 --- /dev/null +++ b/colossalai/nn/layer/parallel_2p5d/__init__.py @@ -0,0 +1,8 @@ +from ._operation import reduce_by_batch_2p5d, split_batch_2p5d +from .layers import (Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D, + VocabParallelClassifier2p5D, VocabParallelEmbedding2p5D) + +__all__ = [ + 'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D', + 'Embedding2p5D', 'VocabParallelClassifier2p5D', 'VocabParallelEmbedding2p5D' +] diff --git a/colossalai/nn/layer/parallel_2p5d/_operation.py b/colossalai/nn/layer/parallel_2p5d/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..5a0f537cd6d9ca0d6104390a8cdb4634f25d204c --- /dev/null +++ b/colossalai/nn/layer/parallel_2p5d/_operation.py @@ -0,0 +1,880 @@ +from typing import Any, Tuple + +import torch +import torch.distributed as dist +from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter) +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + + +def get_parallel_group(parallel_mode: ParallelMode): + return gpc.get_group(parallel_mode) + + +def get_global_rank(): + return gpc.get_global_rank() + + +def get_parallel_rank(parallel_mode: ParallelMode): + return gpc.get_local_rank(parallel_mode) + + +class _Classifier2p5D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + bias, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + A = A.clone().detach() + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + B_temp = all_gather(B, -1, col_parallel_mode) + if ctx: + ctx.save_for_backward(A, B_temp) + + C = torch.matmul(A, B_temp.transpose(0, 1)) + + C = all_reduce(C, row_parallel_mode) + + ctx.use_bias = bias is not None + if bias is not None: + C = C + bias + + out = C.reshape(out_shape) + + if ctx: + ctx.tesseract_dim = tesseract_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + + with torch.no_grad(): + A_grad = torch.matmul(output_grad, B) + A_grad = A_grad.reshape(ctx.A_shape) + B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A) + B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode) + B_grad = B_grad.reshape(ctx.B_shape) + + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1))) + bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode) + else: + bias_grad = None + + return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None + + +def classifier_2p5d(A: Tensor, B: Tensor, bias, tesseract_dim: int, out_shape: Tuple[int, + ...], row_rank: int, col_rank: int, + row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, data_parallel_rank: int, + pipeline_parallel_rank: int, pipeline_parallel_size: int, tensor_parallel_size: int) -> Tensor: + r"""Classifier. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + bias (:class:`torch.tensor`): matrix of bias. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Classifier2p5D.apply(A, B, bias, tesseract_dim, out_shape, row_rank, col_rank, row_parallel_mode, + col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, + tensor_parallel_size) + + +class Matmul_AB_2p5D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = AB`. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + dep_rank (int): the rank of depth. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, + col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: + # A: [b / dq, s, h / q] -> [(b * s) / dq, h / q] + # B: [h / dq, s / q] + # C: [b / dq, s, s / q] -> [(b * s) / dq, s / q] + + assert A.shape[-1] == B.shape[-2], \ + 'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[0], B.shape[-1]) + C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + A_list = [torch.empty_like(A) for _ in range(2)] + B_list = [torch.empty_like(B) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_a = \ + tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_b = \ + col_rank + tesseract_dim ** 2 * dep_rank + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opa = [None] * 2 + opb = [None] * 2 + + A_list[0].copy_(A) + B_list[0].copy_(B) + opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) + opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) + cur = 0 + + for i in range(tesseract_dim): + if i != tesseract_dim - 1: + A_list[1 - cur].copy_(A) + opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) + B_list[1 - cur].copy_(B) + opb[1 - cur] = dist.broadcast(B_list[1 - cur], + src=src_b + tesseract_dim, + group=col_group, + async_op=True) + + if opa[cur] is not None: + opa[cur].wait() + if opb[cur] is not None: + opb[cur].wait() + + torch.addmm(C, A_list[cur], B_list[cur], out=C) + cur = 1 - cur + src_a += 1 + src_b += tesseract_dim + out = C.reshape(out_shape) + + if ctx: + ctx.tesseract_dim = tesseract_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.dep_rank = dep_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + with torch.no_grad(): + A_grad = Matmul_ABT_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, + ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, + ctx.data_parallel_rank, ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + B_grad = Matmul_ATB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, + ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, + ctx.data_parallel_rank, ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class Matmul_ABT_2p5D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = AB^T`. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + dep_rank (int): the rank of depth. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, + col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: + + assert A.shape[-1] == B.shape[-1], \ + 'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[0], B.shape[0]) + C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + B_list = [torch.empty_like(B) for _ in range(2)] + C_list = [torch.empty_like(C) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_b = \ + col_rank + tesseract_dim ** 2 * dep_rank + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_c = \ + tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opb = [None] * 2 + opr = [None] * 2 + + B_list[0].copy_(B) + opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) + cur = 0 + + for i in range(tesseract_dim): + if i != tesseract_dim - 1: + B_list[1 - cur].copy_(B) + opb[1 - cur] = dist.broadcast(B_list[1 - cur], + src=src_b + tesseract_dim, + group=col_group, + async_op=True) + + if opr[cur] is not None: + opr[cur].wait() + if i - 2 == col_rank: + C.copy_(C_list[cur]) + + if opb[cur] is not None: + opb[cur].wait() + + torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur]) + opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True) + cur = 1 - cur + src_b += tesseract_dim + src_c += 1 + + for op in opr: + op.wait() + + if tesseract_dim - 2 == col_rank: + C.copy_(C_list[cur]) + if tesseract_dim - 1 == col_rank: + C.copy_(C_list[1 - cur]) + out = C.reshape(out_shape) + + if ctx: + ctx.tesseract_dim = tesseract_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.dep_rank = dep_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + with torch.no_grad(): + A_grad = Matmul_AB_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, + ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, + ctx.data_parallel_rank, ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + B_grad = Matmul_ATB_2p5D.apply(output_grad, A, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, + ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, + ctx.data_parallel_rank, ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class Matmul_ATB_2p5D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = A^TB` + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + dep_rank (int): the rank of depth. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, + col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int): + + assert A.shape[-2] == B.shape[-2], \ + 'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[-1], B.shape[-1]) + C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + A_list = [torch.empty_like(A) for _ in range(2)] + C_list = [torch.empty_like(C) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_a = \ + tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + src_c = \ + col_rank + tesseract_dim ** 2 * dep_rank + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + + opa = [None] * 2 + opr = [None] * 2 + + A_list[0].copy_(A) + opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) + cur = 0 + + for i in range(tesseract_dim): + if i != tesseract_dim - 1: + A_list[1 - cur].copy_(A) + opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) + + if opr[cur] is not None: + opr[cur].wait() + if i - 2 == row_rank: + C.copy_(C_list[cur]) + + if opa[cur] is not None: + opa[cur].wait() + + torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur]) + opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True) + cur = 1 - cur + src_a += 1 + src_c += tesseract_dim + + for op in opr: + op.wait() + + if tesseract_dim - 2 == row_rank: + C.copy_(C_list[cur]) + if tesseract_dim - 1 == row_rank: + C.copy_(C_list[1 - cur]) + out = C.reshape(out_shape) + + if ctx: + ctx.tesseract_dim = tesseract_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.dep_rank = dep_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + with torch.no_grad(): + A_grad = Matmul_ABT_2p5D.apply(B, output_grad, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, + ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, + ctx.data_parallel_rank, ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + B_grad = Matmul_AB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, + ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, + ctx.data_parallel_rank, ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class _Add_Bias_2p5D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, + row_rank: int, col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: + if row_rank == 0: + bias_temp = bias.clone() + else: + bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device()) + src_rank = \ + col_rank + dep_rank * tesseract_dim ** 2 + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode)) + + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.dep_rank = dep_rank + ctx.tesseract_dim = tesseract_dim + ctx.col_parallel_mode = col_parallel_mode + ctx.bias = skip_bias_add + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + if skip_bias_add: + return bias_temp + else: + output = input + bias_temp + return output + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + row_rank = ctx.row_rank + col_rank = ctx.col_rank + dep_rank = ctx.dep_rank + tesseract_dim = ctx.tesseract_dim + col_parallel_mode = ctx.col_parallel_mode + data_parallel_rank = ctx.data_parallel_rank + pipeline_parallel_rank = ctx.pipeline_parallel_rank + pipeline_parallel_size = ctx.pipeline_parallel_size + tensor_parallel_size = ctx.tensor_parallel_size + + if ctx.bias: + dst_rank = \ + col_rank + dep_rank * (tesseract_dim ** 2) + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + dist.reduce(output_grad, dst=dst_rank, group=get_parallel_group(col_parallel_mode)) + if row_rank == 0: + return \ + None, output_grad, None, None, None, None, None, None, \ + None, None, None, None, None, None, None, None + else: + grad_tmp = torch.zeros_like(output_grad) + return \ + None, grad_tmp, None, None, None, None, None, None, \ + None, None, None, None, None, None, None, None + else: + reduce_dim = tuple(range(output_grad.ndim - 1)) + reduce = torch.sum(output_grad, dim=reduce_dim) + dst_rank = \ + col_rank + dep_rank * (tesseract_dim ** 2) + \ + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ + pipeline_parallel_rank * tensor_parallel_size + dist.reduce(reduce, dst=dst_rank, group=get_parallel_group(col_parallel_mode)) + if row_rank == 0: + return \ + output_grad, reduce, None, None, None, None, None, None, None, \ + None, None, None, None, None, None, None, None + else: + reduce_tmp = torch.zeros_like(reduce) + return \ + output_grad, reduce_tmp, None, None, None, None, None, None, \ + None, None, None, None, None, None, None, None, None + + +def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, row_rank: int, + col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: + r"""Matrix add bias: :math:`C = A + b`. + + Args: + input (:class:`torch.tensor`): matrix :math:`A`. + bias (:class:`torch.tensor`): matrix :math:`B`. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + output_size_per_partition (int): output size in each partition. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + dep_rank (int): the rank of depth. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Add_Bias_2p5D.apply(input, bias, output_size_per_partition, tesseract_dim, row_rank, col_rank, dep_rank, + col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank, + pipeline_parallel_size, tensor_parallel_size) + + +class _Layernorm2p5D(torch.autograd.Function): + r"""Layernorm. + + Args: + input (:class:`torch.tensor`): input matrix. + E_x (:class:`torch.tensor`): mean. + Var_x (:class:`torch.tensor`): variance. + hidden_size (int): hidden size. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, + row_parallel_mode: ParallelMode) -> Tensor: + input = input - E_x + # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) + ctx.hidden_size = hidden_size + output = input * Var_x + ctx.save_for_backward(output, Var_x) + ctx.row_parallel_mode = row_parallel_mode + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + row_parallel_mode = ctx.row_parallel_mode + x, Var_x = ctx.saved_tensors + # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x + with torch.no_grad(): + output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True) + torch.distributed.all_reduce(output_grad_sum, group=get_parallel_group(row_parallel_mode)) + output_grad_sum /= ctx.hidden_size + + output_grad_mul_x_sum = torch.sum(output_grad * x, dim=-1, keepdim=True) + torch.distributed.all_reduce(output_grad_mul_x_sum, group=get_parallel_group(row_parallel_mode)) + output_grad_mul_x_sum /= ctx.hidden_size + + input_grad = output_grad.clone() + input_grad -= x * output_grad_mul_x_sum + input_grad -= output_grad_sum + input_grad *= Var_x + + return input_grad, None, None, None, None, None, None + + +def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, + row_parallel_mode: ParallelMode) -> Tensor: + r"""Layernorm. + + Args: + input (:class:`torch.tensor`): input matrix. + E_x (:class:`torch.tensor`): mean. + Var_x (:class:`torch.tensor`): variance. + hidden_size (int): hidden size. + row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + return _Layernorm2p5D.apply(input, E_x, Var_x, hidden_size, row_parallel_mode) + + +class _AllGatherTensor2p5D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor: + ctx.dim = dim + ctx.col_parallel_mode = col_parallel_mode + + outputs = all_gather(inputs, dim, col_parallel_mode) + return outputs + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + grad = reduce_scatter(output_grad, ctx.dim, ctx.col_parallel_mode) + return grad.contiguous(), None, None + + +def all_gather_tensor_2p5d(inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor: + r"""all gather the weight of 2.5D parallelism. + + Args: + inputs (:class:`torch.tensor`): input tensor. + dim (int): dimension of all-gather. + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + return _AllGatherTensor2p5D.apply(inputs, dim, col_parallel_mode) + + +class SplitFirst(torch.autograd.Function): + r""" + + Args: + inputs (:class:`torch.tensor`): input tensor. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism + col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor: + ctx.tesseract_dim = tesseract_dim + ctx.batch_size = inputs.size(0) + ctx.para_mode = col_parallel_mode + row_rank = gpc.get_local_rank(col_parallel_mode) + + outputs = inputs.chunk(tesseract_dim, dim=0)[row_rank] + return outputs + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + grad_shape = (ctx.batch_size,) + output_grad.shape[1:] + grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device()) + dist.all_gather(list(grad.chunk(ctx.tesseract_dim, dim=0)), + output_grad.contiguous(), + group=gpc.get_group(ctx.para_mode)) + return grad, None, None + + +def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor: + """Splits 2P5D tensor in specified dimension across cols. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + dim (int): Specified dimension in which to split. + + Returns: + :class:`torch.tensor`: The tensor has been split. + """ + dim_size = input_.size(dim) + world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) + + if world_size <= 1: + return input_ + + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' + + return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), + dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous() + + +class _ReduceTensor2p5D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, parallel_mode): + return all_reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return output_grad, None + + +def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: + r"""All-reduce the input. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _ReduceTensor2p5D.apply(input_, parallel_mode) + + +class _ReduceScatterTensor2p5D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + return reduce_scatter(input_, dim, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return all_gather(output_grad, ctx.dim, ctx.parallel_mode), None, None + + +def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""Reduce-scatter the input. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to reduce. + parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + dim_size = input_.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' + + return _ReduceScatterTensor2p5D.apply(input_, dim, parallel_mode) + + +class _RreduceByBatch2p5D(torch.autograd.Function): + + @staticmethod + def symbolic(graph, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) + return output / reduce_size + return output + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) + ctx.reduce_mean = reduce_mean + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) + ctx.reduce_size = reduce_size + return output.clone() / reduce_size + return output.clone() + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + if ctx.reduce_mean: + return output_grad / ctx.reduce_size, None + else: + return output_grad, None + + +def reduce_by_batch_2p5d(input_, reduce_mean: bool = False) -> Tensor: + r"""All-reduce the input from the model parallel region. + + Args: + input_ (:class:`torch.tensor`): input matrix. + reduce_mean (bool, optional): + If set to ``True``, it will divide the output by column parallel size, default to False. + """ + return _RreduceByBatch2p5D.apply(input_, reduce_mean) diff --git a/colossalai/nn/layer/parallel_2p5d/_utils.py b/colossalai/nn/layer/parallel_2p5d/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1478b25de618978ef2c7e060da33edcb47ecff0b --- /dev/null +++ b/colossalai/nn/layer/parallel_2p5d/_utils.py @@ -0,0 +1,25 @@ +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env + + +def get_tesseract_dim_dep_from_env(): + try: + tesseract_dim = env.tesseract_dim + tesseract_dep = env.tesseract_dep + assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero' + assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero' + return tesseract_dim, tesseract_dep + + except KeyError as e: + raise EnvironmentError('TESSERACT_DIM or TESSERACT_DEP is not found in the current environment, ' + 'please make sure that you have used the correct process group initializer') + + +def assert_tesseract_initialization(): + assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL) and \ + gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) and \ + gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) and \ + gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ), \ + 'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ ' \ + 'must be initialized by the process group initializer' diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/nn/layer/parallel_2p5d/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f849cbbe7b0d0f27067941a1c74ccae2ac6fa6f8 --- /dev/null +++ b/colossalai/nn/layer/parallel_2p5d/layers.py @@ -0,0 +1,1198 @@ +import math +from collections import OrderedDict +from typing import Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from colossalai.communication import broadcast +from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn import init as init +from colossalai.registry import LAYERS +from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict) +from colossalai.utils.cuda import get_current_device +from torch import Tensor +from torch.nn import Parameter + +from ..base_layer import ParallelLayer +from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple +from ._operation import (Matmul_AB_2p5D, Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d, + layernorm_2p5d, reduce_scatter_tensor_2p5d, split_batch_2p5d) +from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env + + +@LAYERS.register_module +class Linear2p5D(ParallelLayer): + r"""Linear layer for 2.5D parallelism. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.skip_bias_add = skip_bias_add + + # parallel setting + assert_tesseract_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(in_features, self.tesseract_dim) + self.hidden_size_per_partition = divide(out_features, self.tesseract_dim) + + # create weight, shape: [k/q, h/q] + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter( + torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)) + + # create bias, shape: [h/q] + if bias: + self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs)) + else: + self.register_parameter('bias', None) + + # initialize parameters + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight.transpose(0, 1) + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # broadcast in dep groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0 and \ + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: + broadcast_state_dict(local_state, ParallelMode.PARALLEL_2P5D_DEP) + # partition in column groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + # partition in row groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) == 0: + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in row groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in column groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + local_state[weight_key] = local_state[weight_key].transpose(0, 1) + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + # input: [m/dq, n/q, k/q] + # output: [m/dq, n/q, h/q] + out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) + + output = Matmul_AB_2p5D.apply( + x, + self.weight, + self.tesseract_dim, + out_shape, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + + if self.bias is not None: + if self.skip_bias_add: + bias = add_bias_2p5d(None, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank, + self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + return output, bias + else: + output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, + self.row_rank, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, + False, self.data_parallel_rank, self.pipeline_parallel_rank, + self.pipeline_parallel_size, self.tensor_parallel_size) + return output + else: + return output + + +@LAYERS.register_module +class LayerNorm2p5D(ParallelLayer): + r"""Layer Normalization for 2.5D parallelism. + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None): + super().__init__() + + # layer norm config + self.normalized_shape = normalized_shape + self.variance_epsilon = eps + + # parallel setting + assert_tesseract_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() + + # partitioning dimension + self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # * + + # create parameters + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + + self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) + if bias: + self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs)) + else: + self.bias = None + + self._set_tensor_parallel_attribute() + + def _set_tensor_parallel_attribute(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + with torch.no_grad(): + E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] + torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + E_x /= self.normalized_shape + + # Var_x in the block below is the sum of input^2 + Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] + torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + Var_x /= self.normalized_shape + + Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] + # this time 1/sqrt(Var_x + epsilon) + Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) + + output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW) + scale = add_bias_2p5d(None, self.weight, self.partitioned_partition, self.tesseract_dim, self.row_rank, + self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + if self.bias is not None: + bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank, + self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + output = torch.addcmul(bias, scale, output) + else: + output = torch.mul(scale, output) + return output + + +@LAYERS.register_module +class PatchEmbedding2p5D(ParallelLayer): + r"""2D Image to Patch Embedding. + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_()): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + assert_tesseract_initialization() + self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.embed_size = embed_size + self.embed_size_per_partition = embed_size // self.tesseract_dim**2 + + with seed(ParallelMode.TENSOR): + self.weight = Parameter( + torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size), + device=get_current_device(), + dtype=dtype)) + self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) + + self.cls_token = Parameter( + torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)) + self.pos_embed = Parameter( + torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition), + device=get_current_device(), + dtype=dtype)) + + self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) + self._set_tensor_parallel_attribute() + + def _set_tensor_parallel_attribute(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): + with seed(ParallelMode.TENSOR): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + fan_out = self.embed_size + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + bias_initializer(self.bias, fan_in=fan_in) + position_embed_initializer(self.pos_embed) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + cls_token_key = prefix + 'cls_token' + pos_embed_key = prefix + 'pos_embed' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + # cls token + cls_token = state_dict.pop(cls_token_key, None) + if cls_token is not None: + local_state[cls_token_key] = cls_token + # pos embed + pos_embed = state_dict.pop(pos_embed_key, None) + if pos_embed is not None: + local_state[pos_embed_key] = pos_embed + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + cls_token_key = prefix + 'cls_token' + pos_embed_key = prefix + 'pos_embed' + local_state = OrderedDict({ + weight_key: self.weight, + bias_key: self.bias, + cls_token_key: self.cls_token, + pos_embed_key: self.pos_embed + }) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_2p5d(input_, 0) + + B, C, H, W = input_.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + + weight = all_gather_tensor_2p5d(self.weight, 0, ParallelMode.PARALLEL_2P5D_COL) + bias = all_gather_tensor_2p5d(self.bias, 0, ParallelMode.PARALLEL_2P5D_COL) + + output = F.conv2d(input_, weight, bias, stride=self.patch_size) + if self.flatten: + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + + cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL) + pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL) + cls_token = cls_token.expand(output.shape[0], -1, -1) + output = torch.cat((cls_token, output), dim=1) + output = output + pos_embed + + return output + + +@LAYERS.register_module +class Embedding2p5D(ParallelLayer): + r"""Embedding for 2.5D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + assert_tesseract_initialization() + self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = embedding_dim // self.tesseract_dim**2 + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_2p5d(input_, 0) + + weight = all_gather_tensor_2p5d(self.weight, -1, ParallelMode.PARALLEL_2P5D_COL) + + output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + return output + + +@LAYERS.register_module +class VocabParallelEmbedding2p5D(ParallelLayer): + """Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + assert_tesseract_initialization() + self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() + self.num_embeddings_per_partition = divide(self.num_embeddings, self.tesseract_dim) + self.embed_dim_per_partition = divide(self.embed_dim, self.tesseract_dim) + tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and \ + self.vocab_start_index <= self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_scatter_tensor_2p5d(output_parallel, 0, ParallelMode.PARALLEL_2P5D_COL) + return output + + +@LAYERS.register_module +class Classifier2p5D(ParallelLayer): + r"""Classifier for 2.5D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + assert_tesseract_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(self.in_features, self.tesseract_dim**2) + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.num_classes + col_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_COL)[0] + row_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_ROW)[0] + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2P5D_COL) + broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2P5D_ROW) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + out_shape = input_.shape[:-1] + (self.num_classes,) + + return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank, + self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + + +@LAYERS.register_module +class VocabParallelClassifier2p5D(ParallelLayer): + r"""Vocab parallel classifier layer for 2.5D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + self.in_features = in_features + self.num_classes = num_classes + + # parallel setting + assert_tesseract_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(in_features, self.tesseract_dim) + self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim) + + # create weight, shape: [k/q, h/q] + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.hidden_size_per_partition, self.input_size_per_partition, **factory_kwargs)) + self.has_weight = True + # create bias, shape: [h/q] + if bias: + self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs)) + else: + self.bias = None + + # initialize parameters + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def forward(self, x: Tensor) -> Tensor: + # input: [m/dq, n/q, k/q] + # output: [m/dq, n/q, h/q] + out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) + + output = Matmul_ABT_2p5D.apply( + x, + self.weight, + self.tesseract_dim, + out_shape, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + + if self.bias is not None: + output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank, + self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, False, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + return output diff --git a/colossalai/nn/layer/parallel_3d/__init__.py b/colossalai/nn/layer/parallel_3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9ae255b449ee7f57a08a3bb596102860bb1b60d3 --- /dev/null +++ b/colossalai/nn/layer/parallel_3d/__init__.py @@ -0,0 +1,8 @@ +from ._operation import reduce_by_batch_3d, split_batch_3d, split_tensor_3d +from .layers import (Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D, VocabParallelClassifier3D, + VocabParallelEmbedding3D) + +__all__ = [ + 'reduce_by_batch_3d', 'split_tensor_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', + 'Classifier3D', 'Embedding3D', 'VocabParallelEmbedding3D', 'VocabParallelClassifier3D' +] diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py new file mode 100755 index 0000000000000000000000000000000000000000..5dc9a242851fa79af244f018e4c9e4d2e57e84fc --- /dev/null +++ b/colossalai/nn/layer/parallel_3d/_operation.py @@ -0,0 +1,590 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + +from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter +from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc + +from ._utils import get_parallel_mode_from_env, push_async_grad + + +class _Linear3D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + weight_id: int, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.weight_id = weight_id + ctx.input_parallel_mode = input_parallel_mode + ctx.weight_parallel_mode = weight_parallel_mode + ctx.output_parallel_mode = output_parallel_mode + + input_ = all_gather(input_, 0, input_parallel_mode) + weight = all_gather(weight, 0, weight_parallel_mode) + ctx.save_for_backward(input_, weight) + + output = torch.matmul(input_, weight) + output = reduce_scatter(output, 0, output_parallel_mode) + + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + input_, weight = ctx.saved_tensors + output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) + + input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) + input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) + + weight_grad = torch.matmul( + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) + weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + + input_op.wait() + + return input_grad, weight_grad, None, None, None, None + + +def linear_3d( + input_: Tensor, + weight: Tensor, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: + r"""Linear layer for 3D parallelism. + + Args: + input_ (:class:`torch.tensor`): input matrix. + weight (:class:`torch.tensor`): matrix of weight. + input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. + output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Linear3D.apply( + input_, + weight, + id(weight), + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) + + +class _Classifier3D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + weight_id: int, + bias_id: Optional[int], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.use_bias = bias is not None + ctx.weight_id = weight_id + + src_rank = gpc.get_ranks_in_group(input_parallel_mode)[gpc.get_local_rank(output_parallel_mode)] + weight = broadcast(weight, src_rank, input_parallel_mode) + ctx.save_for_backward(input_, weight) + + output = torch.matmul(input_, weight.transpose(0, 1)) + output = all_reduce(output, output_parallel_mode) + + if bias is not None: + ctx.bias_id = bias_id + output += bias + + ctx.src_rank = src_rank + ctx.input_parallel_mode = input_parallel_mode + ctx.weight_parallel_mode = weight_parallel_mode + ctx.output_parallel_mode = output_parallel_mode + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + input_, weight = ctx.saved_tensors + weight_grad = torch.matmul( + output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1])) + weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode) + if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode): + weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + else: + weight_grad = None + + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) + bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode) + bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + else: + bias_grad = None + + input_grad = torch.matmul(output_grad, weight) + + return input_grad, weight_grad, bias_grad, None, None, None, None, None + + +def classifier_3d( + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: + r"""3D parallel classifier. + + Args: + input_ (:class:`torch.tensor`): input matrix. + weight (:class:`torch.tensor`): matrix of weight. + bias (:class:`torch.tensor`): matrix of bias. + input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. + output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Classifier3D.apply( + input_, + weight, + bias, + id(weight), + id(bias) if bias is not None else None, + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) + + +class _VocabParallelClassifier3D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + weight_id: int, + bias_id: Optional[int], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.use_bias = bias is not None + ctx.weight_id = weight_id + + input_ = all_gather(input_, 0, input_parallel_mode) + weight = all_gather(weight, 0, weight_parallel_mode).transpose(0, 1) + ctx.save_for_backward(input_, weight) + + output = torch.matmul(input_, weight) + output = reduce_scatter(output, 0, output_parallel_mode) + + if bias is not None: + ctx.bias_id = bias_id + output += bias + + ctx.input_parallel_mode = input_parallel_mode + ctx.weight_parallel_mode = weight_parallel_mode + ctx.output_parallel_mode = output_parallel_mode + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + input_, weight = ctx.saved_tensors + output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) + + input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) + input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) + + weight_grad = torch.matmul( + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) + weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) + bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + else: + bias_grad = None + + input_op.wait() + + return input_grad, weight_grad, bias_grad, None, None, None, None, None + + +def vocab_parallel_classifier_3d( + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: + r"""3D vocab parallel classifier. + + Args: + input_ (:class:`torch.tensor`): input matrix. + weight (:class:`torch.tensor`): matrix of weight. + bias (:class:`torch.tensor`): matrix of bias. + input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. + output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _VocabParallelClassifier3D.apply( + input_, + weight, + bias, + id(weight), + id(bias) if bias is not None else None, + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) + + +@torch.jit.script +def norm_forward(x: Tensor, mean: Tensor, sqr_mean: Tensor, weight: Tensor, bias: Tensor, eps: float): + mu = x - mean + var = sqr_mean - mean**2 + sigma = torch.sqrt(var + eps) + z = mu / sigma + output = weight * z + bias + + return output, mu, sigma + + +@torch.jit.script +def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor): + # dbias, dweight = grad, grad * mu / sigma + dz = grad * weight + dmu = dz / sigma + dvar = dz * mu * (-0.5) * sigma**(-3) + dmean = -dmu + dvar = torch.sum(dvar, -1, keepdim=True) + dmean = torch.sum(dmean, -1, keepdim=True) + + return dmu, dmean, dvar + + +class _Layernorm3D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Tensor, + weight_id: int, + bias_id: int, + normalized_shape: int, + eps: float, + output_parallel_mode: ParallelMode, + input_x_weight_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.weight_id = weight_id + ctx.bias_id = bias_id + + sum_ = torch.sum(input_, dim=-1, keepdim=True) + sqr_sum = torch.sum(input_**2, dim=-1, keepdim=True) + mean, sqr_mean = all_reduce(torch.stack((sum_, sqr_sum)), output_parallel_mode) / normalized_shape + + output, mu, sigma = norm_forward(input_, mean, sqr_mean, weight, bias, eps) + + ctx.save_for_backward(mu, sigma, weight) + + ctx.normalized_shape = normalized_shape + ctx.output_parallel_mode = output_parallel_mode + ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode + + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + mu, sigma, weight = ctx.saved_tensors + + bias_grad, weight_grad = output_grad, output_grad * mu / sigma + bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1])) + bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1])) + weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + + dmu, dmean, dvar = norm_backward(output_grad, mu, sigma, weight) + dvar, dmean = all_reduce(torch.stack((dvar, dmean)), ctx.output_parallel_mode) + input_grad = dmu + (dmean + 2 * dvar * mu) / ctx.normalized_shape + + return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None + + +def layernorm_3d( + input_: Tensor, + weight: Tensor, + bias: Tensor, + normalized_shape: int, + eps: float, + output_parallel_mode: ParallelMode, + input_x_weight_parallel_mode: ParallelMode, +) -> Tensor: + r"""3D parallel Layernorm. + + Args: + input_ (:class:`torch.tensor`): input matrix. + weight (:class:`torch.tensor`): matrix of weight. + bias (:class:`torch.tensor`): matrix of bias. + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability + output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Layernorm3D.apply( + input_, + weight, + bias, + id(weight), + id(bias), + normalized_shape, + eps, + output_parallel_mode, + input_x_weight_parallel_mode, + ) + + +def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""Splits 3D parallel tensor in specified dimension. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Specified dimension in which to split. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): Parallel mode. + + Returns: + :class:`torch.tensor`: The tensor has been split. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + dim_size = tensor.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert dim_size % world_size == 0, \ + f'The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), ' \ + f'cannot split tensor evenly' + if tensor.size(dim) <= 1: + return tensor + output = torch.chunk(tensor, gpc.get_world_size(parallel_mode), + dim=dim)[gpc.get_local_rank(parallel_mode)].contiguous() + return output + + +def split_batch_3d(input_: Tensor, + dim: int = 0, + input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT, + weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor: + r"""Splits 3D tensor in batch. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + dim (int): Specified dimension in which to split. + input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): input parallel mode. + weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): weight parallel mode. + + Returns: + :class:`torch.tensor`: The tensor has been split. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + if input_.size(dim) <= 1: + return input_ + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_world_size = gpc.get_world_size(weight_parallel_mode) + input_world_size = gpc.get_world_size(input_parallel_mode) + output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() + output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous() + return output + + +class _ReduceTensor3D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, parallel_mode): + return all_reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return output_grad, None + + +def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor: + r"""All-reduce the input + + Args: + tensor (:class:`torch.tensor`): Input tensor. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + return _ReduceTensor3D.apply(tensor, parallel_mode) + + +class _AllGatherTensor3D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + output = all_gather(input_, dim, parallel_mode) + return output + + @staticmethod + def backward(ctx, output_grad): + input_grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode) + return input_grad, None, None + + +def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""All-reduce the gradient in backward pass. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to gather. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + return _AllGatherTensor3D.apply(tensor, dim, parallel_mode) + + +class _ReduceScatterTensor3D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + return reduce_scatter(input_, dim, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode) + return input_grad, None, None + + +def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""Reduce-scatter the input. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to scatter. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + dim_size = tensor.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size}).' + + return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode) + + +class _ReduceByBatch3D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, + input_: Tensor, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + reduce_mean: bool = False) -> Tensor: + output = all_reduce(input_, input_parallel_mode) + output = all_reduce(output, weight_parallel_mode) + ctx.reduce_mean = reduce_mean + if reduce_mean: + reduce_size = gpc.get_world_size(input_parallel_mode) * gpc.get_world_size(weight_parallel_mode) + ctx.reduce_size = reduce_size + return output.clone() / reduce_size + return output.clone() + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + if ctx.reduce_mean: + return output_grad / ctx.reduce_size, None, None, None + else: + return output_grad, None, None, None + + +def reduce_by_batch_3d(tensor: Tensor, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + reduce_mean: bool = False) -> Tensor: + r"""All-reduce the input from the model parallel region. + + Args: + input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. + reduce_mean (bool, optional): If set to ``True``, it will divide the output by + (input parallel size * weight parallel size), default to False. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean) diff --git a/colossalai/nn/layer/parallel_3d/_utils.py b/colossalai/nn/layer/parallel_3d/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..364191a79f88450ca8701d96258f8e34b7b9b784 --- /dev/null +++ b/colossalai/nn/layer/parallel_3d/_utils.py @@ -0,0 +1,99 @@ +from collections import OrderedDict +from functools import partial + +import torch +from torch import Tensor + +from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env + + +def get_depth_from_env() -> int: + try: + depth = env.depth_3d + assert depth > 0, 'DEPTH must be greater than zero' + return depth + + except KeyError as e: + raise EnvironmentError('DEPTH is not found in the current environment, ' + 'please make sure that you have used the correct process group initializer') + + +def get_parallel_mode_from_env(group): + assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \ + f'{group} is not valid for 3D tensor parallelism.' + return getattr(env, group) + + +def swap_in_out_group(): + env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d + env.input_x_weight_group_3d, env.output_x_weight_group_3d = ( + env.output_x_weight_group_3d, + env.input_x_weight_group_3d, + ) + + +def dbg_check_shape(tensor: Tensor, shape: tuple): + rank = gpc.get_global_rank() + if rank == 0: + print(tensor.shape) + assert tensor.shape == shape, \ + '{} does not match {}'.format(tensor.shape, shape) + + +class AsyncGradientBucket(object): + + def __init__(self): + self.bucket = OrderedDict() + + def __len__(self): + return len(self.bucket) + + def push(self, async_op, grad_tensor, param_id): + self.bucket[param_id] = tuple((async_op, grad_tensor)) + return torch.zeros_like(grad_tensor, dtype=grad_tensor.dtype, device=grad_tensor.device) + + def pop(self, param_id): + grad = None + if param_id in self.bucket: + op, grad = self.bucket.pop(param_id) + if op is not None: + op.wait() + return grad + + def synchronize(self, params): + for p in params: + i = id(p) + if i in self.bucket: + op, grad = self.bucket.pop(i) + if op is not None: + op.wait() + p.grad.add_(grad) + + +_async_grad_bucket = AsyncGradientBucket() + + +def push_async_grad(op, grad, param_id): + return _async_grad_bucket.push(op, grad, param_id) + + +def pop_async_grad(param_id): + return _async_grad_bucket.pop(param_id) + + +def _async_grad_hook(grad, param_id): + grad.add_(pop_async_grad(param_id)) + return grad + + +def register_async_grad_hook(param): + param.register_hook(partial(_async_grad_hook, param_id=id(param))) + + +def synchronize(params=list()): + _async_grad_bucket.synchronize(params) + torch.cuda.default_stream().synchronize() + if len(_async_grad_bucket) > 0: + raise RuntimeError(f"{len(_async_grad_bucket)} asynchronous gradient(s) not collected.") diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..99b0c3f8b7ec339190b28d69842a694cc318fbc7 --- /dev/null +++ b/colossalai/nn/layer/parallel_3d/layers.py @@ -0,0 +1,1218 @@ +import math +from collections import OrderedDict +from typing import Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Parameter + +from colossalai.communication import all_reduce, broadcast +from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D +from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn import init as init +from colossalai.nn.layer.base_layer import ParallelLayer +from colossalai.registry import LAYERS +from colossalai.utils.checkpointing import ( + broadcast_state_dict, + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) +from colossalai.utils.cuda import get_current_device + +from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple +from ._operation import ( + all_gather_tensor_3d, + classifier_3d, + layernorm_3d, + linear_3d, + reduce_scatter_tensor_3d, + split_batch_3d, + split_tensor_3d, + vocab_parallel_classifier_3d, +) +from ._utils import get_depth_from_env, get_parallel_mode_from_env, register_async_grad_hook, swap_in_out_group + + +@LAYERS.register_module +class LayerNorm3D(ParallelLayer): + r"""Layer Normalization for 3D parallelism. + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-12. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None): + + super().__init__() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) + self.depth = get_depth_from_env() + self.normalized_shape = normalized_shape + self.normalized_shape_per_partition = divide(normalized_shape, self.depth) + + self.weight = Parameter( + torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)) + if bias: + self.bias = Parameter( + torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + self.variance_epsilon = eps + self.reset_parameters() + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self) -> None: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.depth) + + def reset_parameters(self) -> None: + init.ones_()(self.weight) + register_async_grad_hook(self.weight) + if self.bias is not None: + init.zeros_()(self.bias) + register_async_grad_hook(self.bias) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight.transpose(0, 1) + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True, + }, + ) + # broadcast in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = broadcast_state_dict(local_state, self.input_parallel_mode) + # broadcast in weight groups + local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + return layernorm_3d( + input_, + self.weight, + self.bias, + self.normalized_shape, + self.variance_epsilon, + self.output_parallel_mode, + self.input_x_weight_parallel_mode, + ) + + +@LAYERS.register_module +class Linear3D(ParallelLayer): + r"""Linear layer for 3D parallelism. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D) + self.depth = get_depth_from_env() + self.skip_bias_add = skip_bias_add + self.in_features_per_partition = divide(in_features, self.depth**2) + self.out_features_per_partition = divide(out_features, self.depth) + self.bias_features_per_partition = divide(out_features, self.depth) + + self.weight = Parameter( + torch.empty(self.in_features_per_partition, + self.out_features_per_partition, + device=get_current_device(), + dtype=dtype)) + if bias: + self.bias = Parameter( + torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + swap_in_out_group() + + def _set_tensor_parallel_attributes(self) -> None: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.depth) + + def _sync_grad_hook(self, grad) -> Tensor: + grad = all_reduce(grad.clone(), self.output_x_weight_parallel_mode) + return grad + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.out_features + + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + register_async_grad_hook(self.weight) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, + gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], + self.output_x_weight_parallel_mode) + self.bias.register_hook(self._sync_grad_hook) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight.transpose(0, 1) + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + # partition in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in weight groups + local_state = partition_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in weight groups + local_state = gather_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + # gather in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + local_state[weight_key] = local_state[weight_key].transpose(0, 1) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + output = linear_3d( + input_, + self.weight, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias + + +@LAYERS.register_module +class Classifier3D(ParallelLayer): + r"""Classifier for 3D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.depth = get_depth_from_env() + self.in_features_per_partition = divide(in_features, self.depth) + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self) -> None: + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.num_classes + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + broadcast(self.weight, gpc.get_ranks_in_group(self.weight_parallel_mode)[0], self.weight_parallel_mode) + + register_async_grad_hook(self.weight) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], ParallelMode.TENSOR) + register_async_grad_hook(self.bias) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + # broadcast in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = broadcast_state_dict(local_state, self.input_parallel_mode) + # broadcast in weight groups + local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + return classifier_3d( + input_, + self.weight, + self.bias, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) + + +@LAYERS.register_module +class VocabParallelClassifier3D(ParallelLayer): + r"""Vocab parallel classifier layer for 3D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D) + self.depth = get_depth_from_env() + self.in_features_per_partition = divide(in_features, self.depth) + self.out_features_per_partition = divide(num_classes, self.depth**2) + self.bias_features_per_partition = divide(num_classes, self.depth) + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.out_features_per_partition, + self.in_features_per_partition, + device=get_current_device(), + dtype=dtype)) + self.has_weight = True + if bias: + self.bias = Parameter( + torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + swap_in_out_group() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self) -> None: + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.depth) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.num_classes + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + register_async_grad_hook(self.weight) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, + gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], + self.output_x_weight_parallel_mode) + register_async_grad_hook(self.bias) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + # partition in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in weight groups + local_state = partition_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in weight groups + local_state = gather_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + # gather in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + return vocab_parallel_classifier_3d( + input_, + self.weight, + self.bias, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) + + +@LAYERS.register_module +class PatchEmbedding3D(ParallelLayer): + r"""2D Image to Patch Embedding. + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_()): + super().__init__() + self.depth = get_depth_from_env() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.embed_size = embed_size + embed_size_per_partition = embed_size // self.depth + self.flatten = flatten + + self.weight = nn.Parameter( + torch.empty((embed_size_per_partition, in_chans, *self.patch_size), + device=get_current_device(), + dtype=dtype)) + self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype)) + + self.cls_token = nn.Parameter( + torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) + self.pos_embed = nn.Parameter( + torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self) -> None: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth) + set_tensor_parallel_attribute_by_partition(self.bias, self.depth) + set_tensor_parallel_attribute_by_partition(self.cls_token, self.depth) + set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth) + + def _sync_grad_hook(self, grad) -> Tensor: + grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode) + return grad + + def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + fan_out = self.embed_size + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + bias_initializer(self.bias, fan_in=fan_in) + position_embed_initializer(self.pos_embed) + + src_rank = gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0] + broadcast(self.weight, src_rank, self.input_x_weight_parallel_mode) + broadcast(self.bias, src_rank, self.input_x_weight_parallel_mode) + broadcast(self.pos_embed, src_rank, self.input_x_weight_parallel_mode) + + self.weight.register_hook(self._sync_grad_hook) + self.bias.register_hook(self._sync_grad_hook) + self.cls_token.register_hook(self._sync_grad_hook) + self.pos_embed.register_hook(self._sync_grad_hook) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + cls_token_key = prefix + 'cls_token' + pos_embed_key = prefix + 'pos_embed' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + # cls token + cls_token = state_dict.pop(cls_token_key, None) + if cls_token is not None: + local_state[cls_token_key] = cls_token + # pos embed + pos_embed = state_dict.pop(pos_embed_key, None) + if pos_embed is not None: + local_state[pos_embed_key] = pos_embed + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + ) + # broadcast in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = broadcast_state_dict(local_state, self.input_parallel_mode) + # broadcast in weight groups + local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + cls_token_key = prefix + 'cls_token' + pos_embed_key = prefix + 'pos_embed' + local_state = OrderedDict({ + weight_key: self.weight, + bias_key: self.bias, + cls_token_key: self.cls_token, + pos_embed_key: self.pos_embed + }) + + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_3d(input_, + input_parallel_mode=self.input_parallel_mode, + weight_parallel_mode=self.weight_parallel_mode) + output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) + if self.flatten: + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + + cls_token = self.cls_token.expand(output.shape[0], -1, -1) + output = torch.cat((cls_token, output), dim=1) + output = output + self.pos_embed + + return output + + +@LAYERS.register_module +class Embedding3D(ParallelLayer): + r"""Embedding for 3D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.depth = get_depth_from_env() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) + + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = divide(embedding_dim, self.depth) + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = nn.Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self) -> None: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth) + + def _sync_grad_hook(self, grad) -> Tensor: + grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode) + return grad + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + broadcast(self.weight, + gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode) + self.weight.register_hook(self._sync_grad_hook) + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + # broadcast in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = broadcast_state_dict(local_state, self.input_parallel_mode) + # broadcast in weight groups + local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_3d(input_, + input_parallel_mode=self.input_parallel_mode, + weight_parallel_mode=self.weight_parallel_mode) + output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + return output + + +@LAYERS.register_module +class VocabParallelEmbedding3D(ParallelLayer): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.depth = get_depth_from_env() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2) + self.embed_dim_per_partition = divide(self.embed_dim, self.depth) + vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode) + self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition * self.depth + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and \ + self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + # partition in weight groups + local_state = partition_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + + # gather in weight groups + local_state = gather_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ + gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) + + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + weight = all_gather_tensor_3d(self.weight, 0, self.weight_parallel_mode) + + output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + output_parallel[input_mask, :] = 0. + output = reduce_scatter_tensor_3d(output_parallel, 0, self.input_parallel_mode) + + return output diff --git a/colossalai/nn/layer/parallel_sequence/__init__.py b/colossalai/nn/layer/parallel_sequence/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4fa9eed6f34b8ccdcf03935337bc96ba705530d0 --- /dev/null +++ b/colossalai/nn/layer/parallel_sequence/__init__.py @@ -0,0 +1,4 @@ +from ._operation import RingQK, RingAV +from .layers import TransformerSelfAttentionRing + +__all__ = ['TransformerSelfAttentionRing', 'RingAV', 'RingQK'] diff --git a/colossalai/nn/layer/parallel_sequence/_operation.py b/colossalai/nn/layer/parallel_sequence/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..fc80494224c6d2ec3176f40c47733133a422b88c --- /dev/null +++ b/colossalai/nn/layer/parallel_sequence/_operation.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +from torch import distributed as dist + +from colossalai.communication import ring_forward +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range +from colossalai.utils import get_current_device +from torch.cuda.amp import custom_bwd, custom_fwd + + +class RingQK(torch.autograd.Function): + """ + Calculate QK in a ring-exchange style + """ + + @staticmethod + @custom_fwd + def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length): + # save tensor for backward + ctx.save_for_backward(sub_q, sub_k) + ctx.sub_seq_length = sub_seq_length + + # create local segment of attention score + attention_score = torch.empty(batch_size * num_attention_heads, + sub_seq_length, + sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE), + dtype=sub_q.dtype, + device=get_current_device()) + + # compute local QK^T + part_a = torch.matmul(sub_q, sub_k.transpose(2, 1)) + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + start_idx = local_rank * sub_seq_length + end_idx = (local_rank + 1) * sub_seq_length + attention_score[:, :, start_idx:end_idx] = part_a + + # compute QK^T in ring-all-reduce style + for i in range(local_world_size - 1): + sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE) + start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length) + part_a = torch.matmul(sub_q, sub_k.transpose(2, 1)) + attention_score[:, :, start_idx:end_idx] = part_a + + return attention_score + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + sub_q, sub_k, = ctx.saved_tensors + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + + # calculate gradient of sub_k + grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q) + + dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE)) + grad_k = grad_k[:, local_rank * ctx.sub_seq_length:(local_rank + 1) * ctx.sub_seq_length] + grad_k /= local_world_size + + # calculate gradient for sub_q + grad_q = torch.zeros_like( + sub_q, + dtype=sub_q.dtype, + device=get_current_device(), + ) + + # compute with local sub_k + start_idx, end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length) + grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k) + + # compute QK^T in ring-all-reduce style + for i in range(local_world_size - 1): + sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE) + start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length) + grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k) + + grad_q /= local_world_size + + return grad_q, grad_k, None, None, None + + +class RingAV(torch.autograd.Function): + """ + Calculate AV in a ring-exchange style + """ + + @staticmethod + @custom_fwd + def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads, attention_head_size, sub_seq_length): + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length) + + sub_attention_result = torch.zeros(batch_size * num_attention_heads, + sub_seq_length, + attention_head_size, + device=get_current_device(), + dtype=attention_score.dtype) + + # save tensors for backward + ctx.save_for_backward(attention_score, sub_v) + ctx.sub_seq_length = sub_seq_length + + # compute local AV + part_av = torch.matmul(attention_score[:, :, local_start_idx:local_end_idx], sub_v) + sub_attention_result += part_av + + # compute AV in ring - all - reduce style + for i in range(local_world_size - 1): + sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE) + start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length) + + # compute QK^T + part_av = torch.matmul(attention_score[:, :, start_idx:end_idx], sub_v) + sub_attention_result += part_av + return sub_attention_result + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + local_start_idx, local_end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length) + attention_scores, sub_v = ctx.saved_tensors + + # calculate gradient of v + grad_v = torch.matmul(attention_scores.transpose(2, 1), grad_output) + dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE)) + grad_v = grad_v[:, local_start_idx:local_end_idx] + grad_v /= local_world_size + + # calculate gradient for attention score + grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device()) + + # compute with local sub_k + grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1)) + + # compute QK^T in ring-all-reduce style + for i in range(local_world_size - 1): + sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE) + start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length) + + # compute grad_q + grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1)) + + return grad_attention_score, grad_v, None, None, None, None diff --git a/colossalai/nn/layer/parallel_sequence/_utils.py b/colossalai/nn/layer/parallel_sequence/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9fad8fab23d2e89d70ef2d82789107db78ebaf08 --- /dev/null +++ b/colossalai/nn/layer/parallel_sequence/_utils.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + + +def _calc_incoming_device_range(i, rank, world_size, sub_seq_length): + device_of_incoming_k = (rank - i - 1) % world_size + start_idx = sub_seq_length * device_of_incoming_k + end_idx = sub_seq_length * (device_of_incoming_k + 1) + return start_idx, end_idx + + +def _calc_current_device_range(rank, sub_seq_length): + start_idx = sub_seq_length * rank + end_idx = sub_seq_length * (rank + 1) + return start_idx, end_idx diff --git a/colossalai/nn/layer/parallel_sequence/layers.py b/colossalai/nn/layer/parallel_sequence/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d9486217bbc93a9f75cfda809701cafaa15957cc --- /dev/null +++ b/colossalai/nn/layer/parallel_sequence/layers.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +import colossalai + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV +from colossalai.registry import LAYERS +from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType +from colossalai.kernel import FusedScaleMaskSoftmax +from colossalai.context import seed + + +@LAYERS.register_module +class TransformerSelfAttentionRing(nn.Module): + """Parallel self-attention layer abstract class. + Self-attention layer takes input with size [b, s, h] + and returns output of the same size. + + Args: + hidden_size (int): hidden size. + num_attention_heads (int): number of attention heads. + attention_dropout (float): dropout probability for attention layer. + attention_mask_func (:class:`typing.Callable`): Mask function to be applied. + layer_number (int): number of layers. + + """ + + def __init__(self, + hidden_size, + num_attention_heads, + attention_dropout, + attention_mask_func, + layer_number, + apply_query_key_layer_scaling: bool = False, + convert_fp16_to_fp32_in_softmax: bool = False, + attn_mask_type=AttnMaskType.padding, + masked_softmax_fusion=True, + fp16=False, + bf16=False): + super().__init__() + self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_mask_func = attention_mask_func + self.layer_number = layer_number + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.attn_mask_type = attn_mask_type + assert self.layer_number > 0 + self.attention_dropout = attention_dropout + + if self.apply_query_key_layer_scaling: + self.convert_fp16_to_fp32_in_softmax = True + + assert self.hidden_size % self.num_attention_heads == 0, \ + 'hidden size is not divisible by the number of attention heads' + + self.hidden_size_per_attention_head = self.hidden_size // num_attention_heads + + self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + + # Strided linear layer. + self.query_key_value = _Linear( + hidden_size, + 3 * self.hidden_size, + ) + + self.coeff = None + self.norm_factor = math.sqrt(self.hidden_size) + + if self.apply_query_key_layer_scaling: + self.coeff = layer_number + self.norm_factor *= self.coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax(fp16, bf16, self.attn_mask_type, masked_softmax_fusion, + self.attention_mask_func, self.convert_fp16_to_fp32_in_softmax, + self.coeff) + + self.attention_dropout = nn.Dropout(attention_dropout) + + # Output. + self.dense = _Linear(hidden_size, hidden_size, bias=True, skip_bias_add=True) + + def forward(self, hidden_states, attention_mask): + # hidden_states: [sub_seq_len, batch_size, hidden_size] + # attention_mask: [batch_size, 1, sub_seq_len, seq_len] + sub_seq_length, batch_size, hidden_size = hidden_states.size() + + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads shape change: + # [sub_seq_len, batch_size, hidden_size] --> [sub_seq_len, batch_size, (3 * head_size * num_heads)] + mixed_x_layer = self.query_key_value(hidden_states) + + # [sub_seq_len, batch_size, num_heads, 3 * head_size] --> 3 [sub_seq_len, batch_size, num_heads, head_size] + new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # split into query, key and value + last_dim = mixed_x_layer.dim() - 1 + last_dim_value = mixed_x_layer.size(-1) + assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \ + 'cannot be divided into query, key and value' + partition_size = last_dim_value // 3 + (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, partition_size, dim=last_dim) + + # attention scores: [batch_size, num_heads, sub_seq_len, seq_len] + output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), + key_layer.size(0) * self.world_size) + + # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size] + key_layer = key_layer.view(key_layer.size(0), output_size[0] * output_size[1], -1) + + # attention_scores: [batch_size * num_heads, sub_seq_len, seq_len] + attention_scores = RingQK.apply( + query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size] + key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size], + batch_size, + self.num_attention_heads, + sub_seq_length) + + attention_scores /= self.norm_factor + + # change view to [batch_size, num_heads, sub_seq_len, seq_len] + attention_scores = attention_scores.view(*output_size) + + # change shape to [batch_size, num_heads, sub_seq_len, seq_len] + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + with seed(ParallelMode.TENSOR): + attention_probs = self.attention_dropout(attention_probs) + + # context layer shape: [batch_size, num_heads, sub_seq_len, head_size] + output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + + # change view [sub_seq_len, batch_size * num_heads, head_size] + value_layer = value_layer.contiguous().view(value_layer.size(0), output_size[0] * output_size[1], -1) + + # # change view [b * num_heads, sub_seq_len, seq_len] + attention_probs = attention_probs.view( + attention_probs.size(0) * attention_probs.size(1), attention_probs.size(2), attention_probs.size(3)) + + # matmul: [batch_size * num_heads, sub_seq_len, head_size] + context_layer = RingAV.apply(attention_probs, + value_layer.transpose(0, 1).contiguous(), batch_size, self.num_attention_heads, + self.hidden_size_per_attention_head, sub_seq_length) + + # change view [batch_size, num_heads, sub_seq_len, head_size] + context_layer = context_layer.view(*output_size) + + # [batch_size, num_heads, sub_seq_len, head_size] -> [sub_seq_len, batch_size, num_heads, head_size] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_attention_head * + self.num_attention_heads,) + context_layer = context_layer.view(*new_context_layer_shape) + + output, bias = self.dense(context_layer) + + return output, bias + + def __repr__(self): + return f'TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, ' \ + f'layer_number={self.layer_number}, hidden_size:{self.hidden_size}, attention_dropout={self.attention_dropout}, ' \ + f'attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, ' \ + f'hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, ' \ + f'convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})' + + +class _Linear(nn.Module): + """Linear layer with column parallelism. + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. + """ + + def __init__(self, input_size, output_size, bias=True, skip_bias_add=False): + super(_Linear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + + self.weight = Parameter(torch.empty( + self.output_size, + self.input_size, + )) + nn.init.xavier_normal_(self.weight) + + if bias: + self.bias = Parameter(torch.empty(self.output_size)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_): + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output = F.linear(input_, self.weight, bias) + + if self.skip_bias_add: + return output, self.bias + else: + return output + + def __repr__(self): + return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \ + f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})' diff --git a/colossalai/nn/layer/utils/__init__.py b/colossalai/nn/layer/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e999ee8214916d9d2b5465333262d05cad198ec --- /dev/null +++ b/colossalai/nn/layer/utils/__init__.py @@ -0,0 +1,7 @@ +from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode, + set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple) + +__all__ = [ + 'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size', + 'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple' +] diff --git a/colossalai/nn/layer/utils/common.py b/colossalai/nn/layer/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..f2297304fdc939c6a57e9eaa0f52ee268ead46fd --- /dev/null +++ b/colossalai/nn/layer/utils/common.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import collections.abc +from itertools import repeat + +import numpy as np +import torch +from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.utils import checkpoint +from torch import Tensor, nn + + +class CheckpointModule(nn.Module): + + def __init__(self, checkpoint: bool = True, offload: bool = False): + super().__init__() + self.checkpoint = checkpoint + self._use_checkpoint = checkpoint + self._offload = offload + + def _forward(self, *args, **kwargs): + raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward') + + def forward(self, *args, **kwargs): + if self._use_checkpoint: + return checkpoint(self._forward, self._offload, *args, **kwargs) + else: + return self._forward(*args, **kwargs) + + def train(self, mode: bool = True): + self._use_checkpoint = self.checkpoint + return super().train(mode=mode) + + def eval(self): + self._use_checkpoint = False + return super().eval() + + +def divide(numerator, denominator): + """Only allow exact division. + + Args: + numerator (int): Numerator of the division. + denominator (int): Denominator of the division. + + Returns: + int: the result of exact division. + """ + assert denominator != 0, 'denominator can not be zero' + assert numerator % denominator == 0, \ + '{} is not divisible by {}'.format(numerator, denominator) + return numerator // denominator + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +def set_tensor_parallel_attribute_by_size(param, size): + setattr(param, IS_TENSOR_PARALLEL, True) + setattr(param, NUM_PARTITIONS, size // np.prod(param.shape)) + + +def set_tensor_parallel_attribute_by_partition(param, num_partitions): + setattr(param, IS_TENSOR_PARALLEL, True) + setattr(param, NUM_PARTITIONS, num_partitions) + + +def get_tensor_parallel_mode(): + return env.mode + + +# From PyTorch internals + + +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) diff --git a/colossalai/nn/layer/vanilla/__init__.py b/colossalai/nn/layer/vanilla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d767b8886f53fe5ddb697fd9fb4ed261cfd05c3 --- /dev/null +++ b/colossalai/nn/layer/vanilla/__init__.py @@ -0,0 +1,14 @@ +from .layers import ( + DropPath, + VanillaClassifier, + VanillaLayerNorm, + VanillaLinear, + VanillaPatchEmbedding, + WrappedDropout, + WrappedDropPath, +) + +__all__ = [ + "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath", + "VanillaLinear" +] diff --git a/colossalai/nn/layer/vanilla/layers.py b/colossalai/nn/layer/vanilla/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..225aed3916a6dbacaa8876ea994c2d2441713338 --- /dev/null +++ b/colossalai/nn/layer/vanilla/layers.py @@ -0,0 +1,341 @@ +import math +from typing import Callable + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch import nn as nn +from torch.nn.parameter import Parameter + +from colossalai.context import seed +from colossalai.nn import init as init +from colossalai.registry import LAYERS +from colossalai.utils.cuda import get_current_device + +from ..utils import to_2tuple + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + Args: + drop_prob (float, optional): probability of dropping path, defaults 0.0. + training (bool, optional): whether in training progress, defaults False. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + + Args: + drop_prob (float, optional): probability of dropping path, defaults None. + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class WrappedDropout(nn.Module): + r"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. During training, randomly zeroes + some elements of the input tensor with probability p using samples from a Bernoulli distribution. Each + channel will be zeroed out independently on every forward call. Furthermore, the outputs are scaled by a factor of + 1/(1-p) during training. This means that during evaluation the module simply computes an identity function. + + Args: + p (float, optional): probability of an element to be zeroed, defaults 0.5. + inplace (bool, optional): whether to do dropout in-place, default to be False. + mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + def __init__(self, p: float = 0.5, inplace: bool = False, mode=None): + super().__init__() + if p < 0 or p > 1: + raise ValueError("dropout probability has to be between 0 and 1, " + "but got {}".format(p)) + self.p = p + self.inplace = inplace + if mode is None: + self.func = self.nonefunc + else: + self.func = self.normalfunc + self.mode = mode + + def nonefunc(self, inputs): + return F.dropout(inputs, self.p, self.training, self.inplace) + + def normalfunc(self, inputs): + with seed(self.mode): + return F.dropout(inputs, self.p, self.training, self.inplace) + + def forward(self, inputs): + return self.func(inputs) + + +class WrappedDropPath(nn.Module): + r"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + Here, it is wrapped with the context of seed manager. + + Args: + p (float, optional): probability of dropping path, defaults 0.0. + mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + def __init__(self, p: float = 0., mode=None): + super().__init__() + self.p = p + self.mode = mode + if self.mode is None: + self.func = self.nonefunc + else: + self.func = self.normalfunc + self.mode = mode + + def nonefunc(self, inputs): + return drop_path(inputs, self.p, self.training) + + def normalfunc(self, inputs): + with seed(self.mode): + return drop_path(inputs, self.p, self.training) + + def forward(self, inputs): + return self.func(inputs) + + +@LAYERS.register_module +class VanillaPatchEmbedding(nn.Module): + r""" + 2D Image to Patch Embedding + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_()): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.weight = nn.Parameter( + torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)) + self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) + self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype)) + self.pos_embed = nn.Parameter( + torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) + + def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): + fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight) + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + bias_initializer(self.bias, fan_in=fan_in) + position_embed_initializer(self.pos_embed) + + def forward(self, input_: Tensor) -> Tensor: + B, C, H, W = input_.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) + if self.flatten: + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + + cls_token = self.cls_token.expand(output.shape[0], -1, -1) + output = torch.cat((cls_token, output), dim=1) + output = output + self.pos_embed + return output + + +@LAYERS.register_module +class VanillaClassifier(nn.Module): + r"""Dense linear classifier. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: nn.Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = nn.Parameter( + torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype)) + self.has_weight = True + if bias: + self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + + def reset_parameters(self, weight_initializer, bias_initializer): + fan_in, fan_out = self.in_features, self.num_classes + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tensor: + return F.linear(input_, self.weight, self.bias) + + +@LAYERS.register_module +class VanillaLayerNorm(nn.Module): + r""" + Layer Normalization for colossalai + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): + super().__init__() + + self.normalized_shape = (normalized_shape,) + self.variance_epsilon = eps + + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + + self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) + if bias: + self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs)) + else: + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon) + + +@LAYERS.register_module +class VanillaLinear(nn.Module): + """Linear layer. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + skip_bias_add: bool (optional, default to be false). + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.skip_bias_add = skip_bias_add + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + weight_initializer(self.weight, fan_in=in_features, fan_out=out_features) + if self.bias is not None: + bias_initializer(self.bias, fan_in=in_features) + + def forward(self, input: Tensor) -> Tensor: + if not self.skip_bias_add: + return F.linear(input, self.weight, self.bias) + else: + return F.linear(input, self.weight), self.bias diff --git a/colossalai/nn/layer/wrapper/__init__.py b/colossalai/nn/layer/wrapper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7d90d887ec6612e351713e508d96b106b767a81 --- /dev/null +++ b/colossalai/nn/layer/wrapper/__init__.py @@ -0,0 +1,3 @@ +from .pipeline_wrapper import PipelineSharedModuleWrapper + +__all__ = ['PipelineSharedModuleWrapper'] diff --git a/colossalai/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/nn/layer/wrapper/pipeline_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..ef1d794cc68f15a1885776016601e1b850d933b8 --- /dev/null +++ b/colossalai/nn/layer/wrapper/pipeline_wrapper.py @@ -0,0 +1,46 @@ +import torch.nn as nn +import torch.distributed as dist +from typing import List, Tuple, Union +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + + +class PipelineSharedModuleWrapper: + + def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None: + assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}' + self.pipeline_ranks = pipeline_ranks + self.group = None + self.ranks_in_group = None + self._init_group() + + def _init_group(self): + world_size = gpc.get_world_size(ParallelMode.GLOBAL) + dp_size = gpc.get_world_size(ParallelMode.DATA) + pp_size = gpc.get_world_size(ParallelMode.PIPELINE) + rank = gpc.get_global_rank() + num_dp_groups = world_size // dp_size + num_pp_stages = num_dp_groups // pp_size + for i in range(dp_size): + for j in range(num_pp_stages): + pipeline_ranks = list(range(i * num_dp_groups + j, (i + 1) * num_dp_groups, num_pp_stages)) + sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks] + group = dist.new_group(sub_ranks) + if rank in sub_ranks: + self.group = group + self.ranks_in_group = sub_ranks + + def register_module(self, module: nn.Module): + assert self.ranks_in_group is not None,\ + f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' + src = self.ranks_in_group[self.pipeline_ranks[0]] + for p in module.parameters(): + setattr(p, 'pipeline_shared_module_pg', self.group) + dist.broadcast(p, src, group=self.group) + + def register_parameter(self, param: nn.Parameter): + assert self.ranks_in_group is not None,\ + f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' + src = self.ranks_in_group[self.pipeline_ranks[0]] + setattr(param, 'pipeline_shared_module_pg', self.group) + dist.broadcast(param, src, group=self.group) diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..373e4ec9468bc13317d74c19b5922073a5cb8c0c --- /dev/null +++ b/colossalai/nn/loss/__init__.py @@ -0,0 +1,41 @@ +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn.layer.utils import get_tensor_parallel_mode +from torch import nn +from torch.nn.modules.loss import * +from torch.nn.modules.loss import _Loss + +from .loss_1d import VocabParallelCrossEntropyLoss1D +from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D +from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D +from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D +from .loss_moe import MoeCrossEntropyLoss, MoeLoss + +_parallel_cross_entropy = { + '2d': CrossEntropyLoss2D, + '2.5d': CrossEntropyLoss2p5D, + '3d': CrossEntropyLoss3D, +} + +_vocab_parallel_cross_entropy = { + '1d': VocabParallelCrossEntropyLoss1D, + '2d': VocabParallelCrossEntropyLoss2D, + '2.5d': VocabParallelCrossEntropyLoss2p5D, + '3d': VocabParallelCrossEntropyLoss3D, +} + + +class CrossEntropyLoss(_Loss): + + def __init__(self, reduction: bool = True, *args, **kwargs): + super().__init__() + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel is not None and env.vocab_parallel: + self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs) + elif tensor_parallel is None or tensor_parallel == '1d': + reduction = 'mean' if reduction else 'none' + self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs) + else: + self.loss = _parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs) + + def forward(self, *args): + return self.loss(*args) diff --git a/colossalai/nn/loss/loss_1d.py b/colossalai/nn/loss/loss_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..2fabd954f8fb7daeb7b457f0c2b216d191e6ed69 --- /dev/null +++ b/colossalai/nn/loss/loss_1d.py @@ -0,0 +1,105 @@ +import torch +import torch.distributed as dist +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.registry import LOSSES +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.modules.loss import _Loss + + +class _VocabParallelCrossEntropy1D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, vocab_parallel_logits, targets, process_group): + if process_group is None: + process_group = gpc.get_group(ParallelMode.PARALLEL_1D) + + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group) + # Subtract the maximum value. + vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + + # Get the partition's vocab indecies + partition_vocab_size = vocab_parallel_logits.size()[-1] + rank = dist.get_rank(process_group) + vocab_start_index = partition_vocab_size * rank + vocab_end_index = vocab_start_index + partition_vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index) + masked_target = targets.clone() - vocab_start_index + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(targets) + predicted_logits[target_mask] = 0.0 + # All reduce is needed to get the chunks from other GPUs. + torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) + + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = torch.exp(vocab_parallel_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits) - predicted_logits + # Store softmax, target-mask and masked-target for backward pass. + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + return loss + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + + # All the inputs have softmax as thier gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None, None + + +@LOSSES.register_module +class VocabParallelCrossEntropyLoss1D(_Loss): + """Vocab parallel cross entropy loss for 1D parallelism. + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + """ + + def __init__(self, reduction=True): + super().__init__() + self.reduction_mean = reduction + + def forward(self, logits, targets, process_group=None): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + """ + loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group) + if self.reduction_mean: + loss = loss.mean() + return loss diff --git a/colossalai/nn/loss/loss_2d.py b/colossalai/nn/loss/loss_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..cb12e723c3232446bf2b911730ff31f7274ea6e2 --- /dev/null +++ b/colossalai/nn/loss/loss_2d.py @@ -0,0 +1,156 @@ +import torch +import torch.distributed as dist +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d +from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization +from colossalai.registry import LOSSES +from colossalai.utils import get_current_device +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.functional import cross_entropy +from torch.nn.modules.loss import _Loss + + +@LOSSES.register_module +class CrossEntropyLoss2D(_Loss): + r"""Cross entropy loss for 2D parallelism + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + + The ``args`` and ``kwargs`` should include parameters below: + :: + + weight (Tensor, optional) + size_average (bool, optional) + ignore_index (int, optional) + reduce (bool, optional) + label_smoothing (float, optional) + + More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in + `Cross_entropy `_. + """ + + def __init__(self, reduction=True, *args, **kwargs): + super().__init__() + assert_summa_initialization() + self.reduction_mean = reduction + self.loss_args = args + self.loss_kwargs = kwargs + + def forward(self, logits, targets): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + + Returns: + float: the loss between logits and targets. + """ + targets = split_batch_2d(targets) + loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_2d(loss, True) + return loss + + +class _VocabParallelCrossEntropy2D(torch.autograd.Function): + ### Modified based on megatron.mpu.cross_entropy ### + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, logits, targets): + # logits: [b/q, h/q] + # labels: [b/q] + # loss: [b/q] + # vocab_parallel_logits: [b/q, s, v/q] + # target: [b/q, s] + logits_max = torch.max(logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, + op=torch.distributed.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + # Subtract the maximum value. + # vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + logits = logits - logits_max.unsqueeze(dim=-1) + + vocab_size = logits.size(-1) + rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + vocab_start = rank * (vocab_size) + vocab_end = (rank + 1) * (vocab_size) - 1 + + target_mask = (targets < vocab_start) | (targets > vocab_end) + + masked_target = targets.clone() - vocab_start + masked_target[target_mask] = 0 + arange_1d = torch.arange( + start=0, + end=logits.size()[0], + ) + predicted_logits = logits[arange_1d, masked_target] + predicted_logits[target_mask] = 0. + dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + + exp_logits = torch.exp(logits) + sum_exp_logits = exp_logits.sum(dim=1) + dist.all_reduce(sum_exp_logits, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + + loss = torch.log(sum_exp_logits) - predicted_logits + + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target) + + return loss + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target = ctx.saved_tensors + + # All the inputs have softmax as their gradient. + grad_input = softmax + + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(output_grad.unsqueeze(dim=-1)) + + return grad_input, None + + +@LOSSES.register_module +class VocabParallelCrossEntropyLoss2D(_Loss): + """Vocab parallel cross entropy loss for 2D parallelism. + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + """ + + def __init__(self, reduction=True): + super().__init__() + self.reduction_mean = reduction + + def forward(self, logits, targets): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + """ + targets = split_batch_2d(targets) + loss = _VocabParallelCrossEntropy2D.apply( + logits, + targets, + ) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_2d(loss, True) + return loss diff --git a/colossalai/nn/loss/loss_2p5d.py b/colossalai/nn/loss/loss_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..f8e3324fc5ff8fe3d28ea25798e33ff9aeb26d25 --- /dev/null +++ b/colossalai/nn/loss/loss_2p5d.py @@ -0,0 +1,149 @@ +import torch +import torch.distributed as dist +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d +from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization +from colossalai.registry import LOSSES +from colossalai.utils import get_current_device +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.functional import cross_entropy +from torch.nn.modules.loss import _Loss + + +@LOSSES.register_module +class CrossEntropyLoss2p5D(_Loss): + r"""Cross entropy loss for 2.5D parallelism + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + + The ``args`` and ``kwargs`` should include parameters below: + :: + + weight (Tensor, optional) + size_average (bool, optional) + ignore_index (int, optional) + reduce (bool, optional) + label_smoothing (float, optional) + + More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in + `Cross_entropy `_. + """ + + def __init__(self, reduction=True, *args, **kwargs): + super().__init__() + assert_tesseract_initialization() + self.reduction_mean = reduction + self.loss_args = args + self.loss_kwargs = kwargs + + def forward(self, logits, targets): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + """ + targets = split_batch_2p5d(targets) + loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_2p5d(loss, True) + return loss + + +class _VocabParallelCrossEntropy2p5D(torch.autograd.Function): + ### Modified based on megatron.mpu.cross_entropy ### + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, logits, targets): + # logits: [b/dq, h/q] + # loss: [b/dq] + # targets: [b/dq, h/q] + logits_max = torch.max(logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, + op=torch.distributed.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + # Subtract the maximum value. + logits = logits - logits_max.unsqueeze(dim=-1) + + vocab_size = logits.size(-1) + rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + vocab_start = rank * (vocab_size) + vocab_end = (rank + 1) * (vocab_size) - 1 + + target_mask = (targets < vocab_start) | (targets > vocab_end) + + masked_target = targets.clone() - vocab_start + masked_target[target_mask] = 0 + arange_1d = torch.arange( + start=0, + end=logits.size()[0], + ) + predicted_logits = logits[arange_1d, masked_target] + predicted_logits[target_mask] = 0. + dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + + exp_logits = torch.exp(logits) + sum_exp_logits = exp_logits.sum(dim=1) + dist.all_reduce(sum_exp_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + + loss = torch.log(sum_exp_logits) - predicted_logits + + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target) + + return loss + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target = ctx.saved_tensors + + # All the inputs have softmax as their gradient. + grad_input = softmax + + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(output_grad.unsqueeze(dim=-1)) + + return grad_input, None + + +@LOSSES.register_module +class VocabParallelCrossEntropyLoss2p5D(_Loss): + """ + Vocab parallel cross entropy loss for 2.5D parallelism + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + """ + + def __init__(self, reduction=True): + super().__init__() + self.reduction_mean = reduction + + def forward(self, logits, targets): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + """ + targets = split_batch_2p5d(targets) + loss = _VocabParallelCrossEntropy2p5D.apply(logits, targets) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_2p5d(loss, True) + + return loss diff --git a/colossalai/nn/loss/loss_3d.py b/colossalai/nn/loss/loss_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..e76439191fdbc8fc31737e21b5c2b17e6685059b --- /dev/null +++ b/colossalai/nn/loss/loss_3d.py @@ -0,0 +1,147 @@ +import torch +import torch.distributed as dist +from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d +from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env +from colossalai.registry import LOSSES +from colossalai.utils import get_current_device +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.functional import cross_entropy +from torch.nn.modules.loss import _Loss + + +@LOSSES.register_module +class CrossEntropyLoss3D(_Loss): + r"""Cross entropy loss for 3D parallelism. + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + + The ``args`` and ``kwargs`` should include parameters below: + :: + + weight (Tensor, optional) + size_average (bool, optional) + ignore_index (int, optional) + reduce (bool, optional) + label_smoothing (float, optional) + + More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in + `Cross_entropy `_. + """ + + def __init__(self, reduction=True, *args, **kwargs): + super().__init__() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.reduction_mean = reduction + self.loss_args = args + self.loss_kwargs = kwargs + + def forward(self, logits, targets): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + """ + targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) + targets = split_tensor_3d(targets, 0, self.input_parallel_mode) + loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_3d(loss, self.input_parallel_mode, self.weight_parallel_mode, True) + return loss + + +class _VocabParallelCrossEntropy3D(torch.autograd.Function): + # Adapted from megatron.mpu.cross_entropy + # loss[i] = -logits[i][targets] + log(sum(exp(logits[i]))) + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, logits, targets, output_parallel_mode): + # logits: [b/q^2, c/q] + # labels: [b/q^2] + # loss: [b/q^2] + logits_max = torch.max(logits, dim=-1)[0] + dist.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(output_parallel_mode)) + # Subtract the maximum value. + logits = logits - logits_max.unsqueeze(dim=-1) + + vocab_size_per_partition = logits.size()[-1] + rank = gpc.get_local_rank(output_parallel_mode) + vocab_start = rank * vocab_size_per_partition + vocab_end = (rank + 1) * vocab_size_per_partition - 1 + + # loss[i] = 0 if targets[i] < vocab_start or targets[i] > vocab_end + target_mask = (targets < vocab_start) | (targets > vocab_end) + masked_target = targets.clone() - vocab_start + masked_target[target_mask] = 0 + arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device()) + predicted_logits = logits[arange_1d, masked_target] + predicted_logits = predicted_logits.clone().contiguous().view_as(targets) + predicted_logits[target_mask] = 0. + dist.all_reduce(predicted_logits, group=gpc.get_group(output_parallel_mode)) + + # Loss = log(sum(exp(logits))) - predicted-logit. + exp_logits = torch.exp(logits) + sum_exp_logits = exp_logits.sum(dim=-1) + dist.all_reduce(sum_exp_logits, group=gpc.get_group(output_parallel_mode)) + loss = torch.log(sum_exp_logits) - predicted_logits + + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target) + + return loss + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target = ctx.saved_tensors + + # All the inputs have softmax as thier gradient. + input_grad = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = input_grad.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + input_grad.mul_(output_grad.unsqueeze(dim=-1)) + + return input_grad, None, None, None + + +@LOSSES.register_module +class VocabParallelCrossEntropyLoss3D(_Loss): + """Vocab parallel cross entropy loss for 2D parallelism. + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + """ + + def __init__(self, reduction=True): + super().__init__() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.reduction_mean = reduction + + def forward(self, logits, targets): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + """ + targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) + targets = split_tensor_3d(targets, 0, self.input_parallel_mode) + loss = _VocabParallelCrossEntropy3D.apply(logits, targets, self.output_parallel_mode) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_3d(loss, self.input_parallel_mode, self.weight_parallel_mode, True) + return loss diff --git a/colossalai/nn/loss/loss_moe.py b/colossalai/nn/loss/loss_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..a8b18a3e37ee20821df21fa7b2e309c4dc9f02e5 --- /dev/null +++ b/colossalai/nn/loss/loss_moe.py @@ -0,0 +1,80 @@ +import torch.nn as nn +from colossalai.registry import LOSSES +from torch.nn.modules.loss import _Loss +from colossalai.context.moe_context import MOE_CONTEXT + + +@LOSSES.register_module +class MoeCrossEntropyLoss(_Loss): + r"""torch.nn.CrossEntropyLoss added with auxiliary loss. + + Args: + input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + target (:class:`torch.tensor`): Ground truth class indices or class probabilities. + aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01. + + The ``args`` and ``kwargs`` should include parameters below: + :: + + weight (Tensor, optional) + size_average (bool, optional) + ignore_index (int, optional) + reduce (bool, optional) + reduction (str, optional) + label_smoothing (float, optional) + + More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in + `Cross_entropy `_. + """ + + def __init__(self, aux_weight: float = 0.01, *args, **kwargs): + super().__init__() + self.loss = nn.CrossEntropyLoss(*args, **kwargs) + self.aux_weight = aux_weight + + def forward(self, *args): + """ + The ``args`` should at least include parameters below: + :: + + input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + target (:class:`torch.tensor`): Ground truth class indices or class probabilities. + + More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in + `Cross_entropy `_. + """ + main_loss = self.loss(*args) + aux_loss = MOE_CONTEXT.get_loss() + return main_loss + self.aux_weight * aux_loss + + +@LOSSES.register_module +class MoeLoss(_Loss): + """A wrapper class for any loss module to add with auxiliary loss. + + Args: + aux_weight (float): Weight of auxiliary loss in total loss. + loss_fn (``Callable``): Loss function. + args (list): Args in loss function. + kwargs (dict): Kwargs in loss function + """ + + def __init__(self, aux_weight: float, loss_fn, *args, **kwargs): + super().__init__() + self.loss_fn = loss_fn(*args, **kwargs) + self.aux_weight = aux_weight + + def forward(self, *args, **kwargs): + """ + The ``args`` and ``kwargs`` should at least include parameters below: + :: + + input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + target (:class:`torch.tensor`): Ground truth class indices or class probabilities. + + Note: + The ``args`` and ``kwargs`` may include different parameters varying with different loss function. + """ + main_loss = self.loss_fn(*args, **kwargs) + aux_loss = MOE_CONTEXT.get_loss() + return main_loss + self.aux_weight * aux_loss diff --git a/colossalai/nn/lr_scheduler/__init__.py b/colossalai/nn/lr_scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34731ee901a0d37a3296e57915df08c38f0b648b --- /dev/null +++ b/colossalai/nn/lr_scheduler/__init__.py @@ -0,0 +1,12 @@ +from .cosine import CosineAnnealingLR, CosineAnnealingWarmupLR, FlatAnnealingLR, FlatAnnealingWarmupLR +from .linear import LinearWarmupLR +from .multistep import MultiStepLR, MultiStepWarmupLR +from .onecycle import OneCycleLR +from .poly import PolynomialLR, PolynomialWarmupLR +from .torch import LambdaLR, MultiplicativeLR, StepLR, ExponentialLR + +__all__ = [ + 'CosineAnnealingLR', 'CosineAnnealingWarmupLR', 'FlatAnnealingLR', 'FlatAnnealingWarmupLR', 'LinearWarmupLR', + 'MultiStepLR', 'MultiStepWarmupLR', 'OneCycleLR', 'PolynomialLR', 'PolynomialWarmupLR', 'LambdaLR', + 'MultiplicativeLR', 'StepLR', 'ExponentialLR' +] diff --git a/colossalai/nn/lr_scheduler/cosine.py b/colossalai/nn/lr_scheduler/cosine.py new file mode 100644 index 0000000000000000000000000000000000000000..aab523bef8b30dafc65f60a9475a8b9a70326738 --- /dev/null +++ b/colossalai/nn/lr_scheduler/cosine.py @@ -0,0 +1,121 @@ +from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR + +from colossalai.registry import LR_SCHEDULERS +from .delayed import DelayerScheduler, WarmupDelayerScheduler, WarmupScheduler + + +@LR_SCHEDULERS.register_module +class CosineAnnealingLR(_CosineAnnealingLR): + r"""Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial lr and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + + .. math:: + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), + & T_{cur} \neq (2k+1)T_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), + & T_{cur} = (2k+1)T_{max}. + \end{aligned} + + When last_epoch=-1, sets initial lr as lr. Notice that because the schedule + is defined recursively, the learning rate can be simultaneously modified + outside this scheduler by other operators. If the learning rate is set + solely by this scheduler, the learning rate at each step becomes: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only + implements the cosine annealing part of SGDR, and not the restarts. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + eta_min (int, optional): Minimum learning rate, defaults to 0. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, total_steps: int, eta_min: int = 0, last_epoch: int = -1, **kwargs): + super().__init__(optimizer, total_steps, eta_min=eta_min, last_epoch=last_epoch) + + +@LR_SCHEDULERS.register_module +class CosineAnnealingWarmupLR(WarmupScheduler): + """Cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be applied. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + warmup_steps (int, optional): Number of warmup steps, defaults to 0. + eta_min (int, optional): Minimum learning rate, defaults to 0. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0., last_epoch: int = -1): + base_scheduler = _CosineAnnealingLR(optimizer, + total_steps - warmup_steps, + eta_min=eta_min, + last_epoch=last_epoch) + super().__init__(optimizer, warmup_steps, base_scheduler) + + +@LR_SCHEDULERS.register_module +class FlatAnnealingLR(DelayerScheduler): + """Flat and cosine annealing learning rate scheduler. The learning rate will be a fixed value before starting decay. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + pct_start (float, optional): Percent of steps before starting learning rate decay, defaults to -0.72. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, total_steps: int, pct_start: float = 0.72, last_epoch: int = -1, **kwargs): + if not (0.0 <= pct_start <= 1.0): + raise ValueError(f'pct_start must >= 0.0 and <= 1.0, got {pct_start}') + flat_steps = int(total_steps * pct_start) + anneal_steps = total_steps - flat_steps + base_scheduler = _CosineAnnealingLR(optimizer, anneal_steps) + super().__init__(optimizer, flat_steps, base_scheduler, last_epoch=last_epoch) + + +@LR_SCHEDULERS.register_module +class FlatAnnealingWarmupLR(WarmupDelayerScheduler): + """Flat and cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be + applied, and then the learning rate will be a fixed value before starting decay. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + warmup_steps (int, optional): Number of warmup steps, defaults to 0. + pct_start (float, optional): Percent of steps before starting learning rate decay, defaults to -0.72. + eta_min (int, optional): Minimum learning rate, defaults to 0. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + pct_start: float = 0.72, + eta_min: int = 0, + last_epoch: int = -1, + **kwargs): + if not (0.0 <= pct_start <= 1.0): + raise ValueError(f'pct_start must >= 0.0 and <= 1.0, got {pct_start}') + flat_steps = int((total_steps - warmup_steps) * pct_start) + anneal_steps = total_steps - warmup_steps - flat_steps + base_scheduler = _CosineAnnealingLR(optimizer, anneal_steps, eta_min=eta_min) + super().__init__(optimizer, warmup_steps, flat_steps, base_scheduler, last_epoch=last_epoch) diff --git a/colossalai/nn/lr_scheduler/delayed.py b/colossalai/nn/lr_scheduler/delayed.py new file mode 100644 index 0000000000000000000000000000000000000000..a73ff8ae37ace4f82433c55048dac9cf729389fe --- /dev/null +++ b/colossalai/nn/lr_scheduler/delayed.py @@ -0,0 +1,176 @@ +from torch.optim.lr_scheduler import _LRScheduler + + +class _enable_get_lr_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + + +class DelayerScheduler(_LRScheduler): + """Starts with a flat lr schedule until it reaches N epochs then applies + the specific scheduler (For example: ReduceLROnPlateau) + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + delay_epochs (int): Number of epochs to keep the initial lr until starting applying the scheduler. + after_scheduler (:class:`torch.optim.lr_scheduler`): After target_epoch, use this scheduler. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, delay_epochs, after_scheduler, last_epoch=-1): + if delay_epochs < 0: + raise ValueError(f'delay_epochs must >= 0, got {delay_epochs}') + self.delay_epochs = delay_epochs + self.after_scheduler = after_scheduler + self.finished = False + super().__init__(optimizer, last_epoch) + + def state_dict(self): + state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} + if isinstance(state_dict['after_scheduler'], _LRScheduler): + state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ + state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() + del state_dict['after_scheduler'] + else: + raise NotImplementedError() + return state_dict + + def get_lr(self): + if self.last_epoch >= self.delay_epochs: + if not self.finished: + self.after_scheduler.base_lrs = self.base_lrs + self.finished = True + with _enable_get_lr_call(self.after_scheduler): + return self.after_scheduler.get_lr() + + return self.base_lrs + + def step(self, epoch=None): + if self.finished: + if epoch is None: + self.after_scheduler.step(None) + self._last_lr = self.after_scheduler.get_last_lr() + else: + self.after_scheduler.step(epoch - self.delay_epochs) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super(DelayerScheduler, self).step(epoch) + + +class WarmupScheduler(_LRScheduler): + """Starts with a linear warmup lr schedule until it reaches N epochs then applies + the specific scheduler (For example: ReduceLROnPlateau). + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + warmup_epochs (int): Number of epochs to linearly warmup lr until starting applying the scheduler. + after_scheduler (:class:`torch.optim.lr_scheduler`): After target_epoch, use this scheduler. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1): + self.warmup_epochs = int(warmup_epochs) + self.after_scheduler = after_scheduler + self.finished = False + super().__init__(optimizer, last_epoch) + + def state_dict(self): + state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} + if isinstance(state_dict['after_scheduler'], _LRScheduler): + state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ + state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() + del state_dict['after_scheduler'] + else: + raise NotImplementedError() + return state_dict + + def get_lr(self): + if self.last_epoch >= self.warmup_epochs: + if not self.finished: + self.after_scheduler.base_lrs = self.base_lrs + self.finished = True + return self.after_scheduler.get_lr() + + return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs] + + def step(self, epoch=None): + if self.finished: + if epoch is None: + self.after_scheduler.step(None) + self._last_lr = self.after_scheduler.get_last_lr() + else: + self.after_scheduler.step(epoch - self.warmup_epochs) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super().step(epoch) + + +class WarmupDelayerScheduler(_LRScheduler): + """Starts with a linear warmup lr schedule until it reaches N epochs and a flat lr schedule + until it reaches M epochs then applies the specific scheduler (For example: ReduceLROnPlateau). + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + warmup_epochs (int): Number of epochs to linearly warmup lr until starting applying the scheduler. + delay_epochs (int): Number of epochs to keep the initial lr until starting applying the scheduler. + after_scheduler (:class:`torch.optim.lr_scheduler`): After target_epoch, use this scheduler. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, warmup_epochs, delay_epochs, after_scheduler, last_epoch=-1): + if delay_epochs < 0: + raise ValueError(f'delay_epochs must >= 0, got {delay_epochs}') + if warmup_epochs < 0: + raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}') + self.warmup_epochs = warmup_epochs + self.delay_epochs = delay_epochs + self.after_scheduler = after_scheduler + self.finished = False + super().__init__(optimizer, last_epoch) + + def state_dict(self): + state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} + if isinstance(state_dict['after_scheduler'], _LRScheduler): + state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ + state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() + del state_dict['after_scheduler'] + else: + raise NotImplementedError() + return state_dict + + def get_lr(self): + if self.last_epoch >= self.warmup_epochs + self.delay_epochs: + if not self.finished: + self.after_scheduler.base_lrs = self.base_lrs + # reset lr to base_lr + for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs): + group['lr'] = base_lr + self.finished = True + with _enable_get_lr_call(self.after_scheduler): + return self.after_scheduler.get_lr() + elif self.last_epoch >= self.warmup_epochs: + return self.base_lrs + + return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs] + + def step(self, epoch=None): + if self.finished: + if epoch is None: + self.after_scheduler.step(None) + self._last_lr = self.after_scheduler.get_last_lr() + else: + self.after_scheduler.step(epoch - self.warmup_epochs) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super().step(epoch) diff --git a/colossalai/nn/lr_scheduler/linear.py b/colossalai/nn/lr_scheduler/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..556938b8a60c8ae5eaf116e17710f5253f4bce28 --- /dev/null +++ b/colossalai/nn/lr_scheduler/linear.py @@ -0,0 +1,28 @@ +from torch.optim.lr_scheduler import _LRScheduler + +from colossalai.registry import LR_SCHEDULERS + + +@LR_SCHEDULERS.register_module +class LinearWarmupLR(_LRScheduler): + """Linearly warmup learning rate and then linearly decay. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + warmup_steps (int, optional): Number of warmup steps, defaults to 0 + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, last_epoch: int = -1, **kwargs): + self.warmup_steps = warmup_steps + self.total_steps = total_steps + super().__init__(optimizer, last_epoch=last_epoch) + + def get_lr(self): + if self.last_epoch < self.warmup_steps: + return [(self.last_epoch + 1) / (self.warmup_steps + 1) * lr for lr in self.base_lrs] + else: + return [(self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr + for lr in self.base_lrs] diff --git a/colossalai/nn/lr_scheduler/multistep.py b/colossalai/nn/lr_scheduler/multistep.py new file mode 100644 index 0000000000000000000000000000000000000000..29531a9e385524913b9b6bfb4c89c7ce40b05792 --- /dev/null +++ b/colossalai/nn/lr_scheduler/multistep.py @@ -0,0 +1,62 @@ +from typing import List + +from torch.optim.lr_scheduler import MultiStepLR as _MultiStepLR + +from colossalai.registry import LR_SCHEDULERS +from .delayed import WarmupScheduler + + +@LR_SCHEDULERS.register_module +class MultiStepLR(_MultiStepLR): + """Decays the learning rate of each parameter group by gamma once the + number of epoch reaches one of the milestones. Notice that such decay can + happen simultaneously with other changes to the learning rate from outside + this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + milestones (List[int], optional): List of epoch indices. Must be increasing, defaults to None. + gamma (float, optional): Multiplicative factor of learning rate decay, defaults to 0.1. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, + optimizer, + total_steps: int, + milestones: List[int] = None, + gamma: float = 0.1, + last_epoch: int = -1, + **kwargs): + super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch) + + +@LR_SCHEDULERS.register_module +class MultiStepWarmupLR(WarmupScheduler): + """Multistep learning rate scheduler with warmup. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + warmup_steps (int, optional): Number of warmup steps, defaults to 0. + milestones (List[int], optional): List of epoch indices. Must be increasing, defaults to None. + gamma (float, optional): Multiplicative factor of learning rate decay, defaults to 0.1. + num_steps_per_epoch (int, optional): Number of steps per epoch, defaults to -1. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + milestones: List[int] = None, + gamma: float = 0.1, + last_epoch: int = -1, + **kwargs): + if len(milestones) == 0: + raise ValueError('milestones cannot be empty') + milestones = [v - warmup_steps for v in milestones if v >= warmup_steps] + base_scheduler = _MultiStepLR(optimizer, milestones=milestones, gamma=gamma) + super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) diff --git a/colossalai/nn/lr_scheduler/onecycle.py b/colossalai/nn/lr_scheduler/onecycle.py new file mode 100644 index 0000000000000000000000000000000000000000..8007fd36008ea01830500a1f3272d0407d2ca3f2 --- /dev/null +++ b/colossalai/nn/lr_scheduler/onecycle.py @@ -0,0 +1,94 @@ +from torch.optim.lr_scheduler import OneCycleLR as _OneCycleLR + +from colossalai.registry import LR_SCHEDULERS + + +@LR_SCHEDULERS.register_module +class OneCycleLR(_OneCycleLR): + r"""Sets the learning rate of each parameter group according to the + 1cycle learning rate policy. The 1cycle policy anneals the learning + rate from an initial learning rate to some maximum learning rate and then + from that maximum learning rate to some minimum learning rate much lower + than the initial learning rate. + This policy was initially described in the paper `Super-Convergence: + Very Fast Training of Neural Networks Using Large Learning Rates`_. + The 1cycle learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + This scheduler is not chainable. + Note also that the total number of steps in the cycle can be determined in one + of two ways (listed in order of precedence): + + * A value for total_steps is explicitly provided. + * A number of epochs (epochs) and a number of steps per epoch (steps_per_epoch) are provided. + In this case, the number of total steps is inferred by total_steps = epochs * steps_per_epoch + + You must either provide a value for total_steps or provide a value for both + epochs and steps_per_epoch. + The default behaviour of this scheduler follows the fastai implementation of 1cycle, which + claims that "unpublished work has shown even better results by using only two phases". To + mimic the behaviour of the original paper instead, set ``three_phase=True``. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + pct_start (float, optional): + The percentage of the cycle (in number of steps) spent increasing the learning rate, defaults to 0.3. + anneal_strategy (str, optional): {'cos', 'linear'}, Specifies the annealing strategy: + "cos" for cosine annealing, "linear" for linear annealing, defaults to 'cos'. + cycle_momentum (bool, optional): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum', defaults to True. + base_momentum (float, optional): Lower momentum boundaries in the cycle for each parameter group. + Note that momentum is cycled inversely to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr', defaults to 0.85. + max_momentum (float, optional): Upper momentum boundaries in the cycle for each parameter group. + Functionally, it defines the cycle amplitude (max_momentum - base_momentum). + Note that momentum is cycled inversely to learning rate; at the start of a cycle, momentum is 'max_momentum' + and learning rate is 'base_lr', defaults to 0.95. + div_factor (float, optional): Determines the initial learning rate via + initial_lr = max_lr/div_factor, defaults to 25.0. + final_div_factor (float, optional): Determines the minimum learning rate via + min_lr = initial_lr/final_div_factor, defaults to 10000.0. + last_epoch (int, optional): The index of the last batch. This parameter is used when resuming a training job. + Since `step()` should be invoked after each batch instead of after each epoch, this number represents + the total number of *batches* computed, not the total number of epochs computed. + When last_epoch=-1, the schedule is started from the beginning, defaults to -1 + + The ``kwargs`` for initializing torch.optim.lr_scheduler.OneCycleLR should include parameters below: + :: + + epochs (int, optional, default=None) + steps_per_epoch (int, optional, default=None) + three_phase (bool, optional, default=False) + verbose (bool, optional, default=False) + + More details about kwargs could be found in + `OneCycleLR `_. + + .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: + https://arxiv.org/abs/1708.07120 + """ + + def __init__(self, + optimizer, + total_steps: int, + pct_start=0.3, + anneal_strategy='cos', + cycle_momentum=True, + base_momentum=0.85, + max_momentum=0.95, + div_factor=25.0, + final_div_factor=10000.0, + last_epoch=-1, + **kwargs): + max_lrs = list(map(lambda group: group['lr'], optimizer.param_groups)) + super().__init__(optimizer, + max_lrs, + total_steps=total_steps, + pct_start=pct_start, + anneal_strategy=anneal_strategy, + cycle_momentum=cycle_momentum, + base_momentum=base_momentum, + max_momentum=max_momentum, + div_factor=div_factor, + final_div_factor=final_div_factor, + last_epoch=last_epoch) diff --git a/colossalai/nn/lr_scheduler/poly.py b/colossalai/nn/lr_scheduler/poly.py new file mode 100644 index 0000000000000000000000000000000000000000..16352bc5175ff022f3111bec8503443cdae772b9 --- /dev/null +++ b/colossalai/nn/lr_scheduler/poly.py @@ -0,0 +1,66 @@ +from torch.optim.lr_scheduler import _LRScheduler + +from colossalai.registry import LR_SCHEDULERS +from .delayed import WarmupScheduler + + +@LR_SCHEDULERS.register_module +class PolynomialLR(_LRScheduler): + """Polynomial learning rate scheduler. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + end_lr (float, optional): Minimum learning rate, defaults to 0.0001. + power (float, optional): The power of polynomial, defaults to 1.0. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, + optimizer, + total_steps: int, + end_lr: float = 0.0001, + power: float = 1.0, + last_epoch: int = -1, + **kwargs): + if end_lr < 0: + raise ValueError(f'end_lr must >= 0, got {end_lr}') + self.total_steps = total_steps + self.end_lr = end_lr + self.power = power + super().__init__(optimizer, last_epoch=last_epoch) + + def get_lr(self): + return self._get_closed_form_lr() + + def _get_closed_form_lr(self): + return [(base_lr - self.end_lr) * + ((1 - min(self.last_epoch, self.total_steps) / self.total_steps)**self.power) + self.end_lr + for base_lr in self.base_lrs] + + +@LR_SCHEDULERS.register_module +class PolynomialWarmupLR(WarmupScheduler): + """Polynomial learning rate scheduler with warmup. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + warmup_steps (int, optional): Number of warmup steps, defaults to 0. + end_lr (float, optional): Minimum learning rate, defaults to 0.0001. + power (float, optional): The power of polynomial, defaults to 1.0. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + end_lr: float = 0.0001, + power: float = 1.0, + last_epoch: int = -1, + **kwargs): + base_scheduler = PolynomialLR(optimizer, total_steps - warmup_steps, end_lr=end_lr, power=power) + super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) diff --git a/colossalai/nn/lr_scheduler/torch.py b/colossalai/nn/lr_scheduler/torch.py new file mode 100644 index 0000000000000000000000000000000000000000..05d2a49c1ea5b0363f82dc61bc3f737c5c2a845a --- /dev/null +++ b/colossalai/nn/lr_scheduler/torch.py @@ -0,0 +1,77 @@ +from torch.optim.lr_scheduler import LambdaLR as _LambdaLR +from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR +from torch.optim.lr_scheduler import StepLR as _StepLR +from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR + +from colossalai.registry import LR_SCHEDULERS + + +@LR_SCHEDULERS.register_module +class LambdaLR(_LambdaLR): + """Sets the learning rate of each parameter group to the initial lr + times a given function. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + lr_lambda (Union[``function``, ``list[function]``]): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such functions, + one for each group in optimizer.param_groups, defaults to None. + last_epoch (int, optional): The index of last epoch, defaults to -1. + """ + + def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None: + super().__init__(optimizer, lr_lambda, last_epoch=last_epoch) + + +@LR_SCHEDULERS.register_module +class MultiplicativeLR(_MultiplicativeLR): + """Multiply the learning rate of each parameter group by the factor given + in the specified function. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + lr_lambda (Union[``function``, ``list[function]``]): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such functions, + one for each group in optimizer.param_groups, defaults to None. + last_epoch (int, optional): The index of last epoch, defaults to -1. + """ + + def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None: + super().__init__(optimizer, lr_lambda, last_epoch=last_epoch) + + +@LR_SCHEDULERS.register_module +class StepLR(_StepLR): + """Decays the learning rate of each parameter group by gamma every + step_size epochs. Notice that such decay can happen simultaneously with + other changes to the learning rate from outside this scheduler. When + last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + step_size (int, optional): Period of learning rate decay, defaults to 1. + gamma (float, optional): Multiplicative factor of learning rate decay, defaults to 0.1. + last_epoch (int, optional): The index of last epoch, defaults to -1. + """ + + def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, last_epoch: int = -1) -> None: + super().__init__(optimizer, step_size, gamma=gamma, last_epoch=last_epoch) + + +@LR_SCHEDULERS.register_module +class ExponentialLR(_ExponentialLR): + """Decays the learning rate of each parameter group by gamma every epoch. + When last_epoch=-1, sets initial lr as lr + + Args: + optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Wrapped optimizer. + total_steps (int): Number of total training steps. + gamma (float, optional): Multiplicative factor of learning rate decay, defaults to 1.0. + last_epoch (int, optional): The index of last epoch, defaults to -1. + """ + + def __init__(self, optimizer, total_steps, gamma: float = 1.0, last_epoch: int = -1) -> None: + super().__init__(optimizer, gamma, last_epoch=last_epoch) diff --git a/colossalai/nn/metric/__init__.py b/colossalai/nn/metric/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00833b6119c161be0bd2855a1b44b333f2b93f66 --- /dev/null +++ b/colossalai/nn/metric/__init__.py @@ -0,0 +1,26 @@ +from torch import nn + +from ._utils import calc_acc +from .accuracy_2d import Accuracy2D +from .accuracy_2p5d import Accuracy2p5D +from .accuracy_3d import Accuracy3D +from colossalai.nn.layer.utils import get_tensor_parallel_mode + +_parallel_accuracy = { + '2d': Accuracy2D, + '2.5d': Accuracy2p5D, + '3d': Accuracy3D, +} + + +class Accuracy(nn.Module): + def __init__(self): + super().__init__() + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel not in _parallel_accuracy: + self.acc = calc_acc + else: + self.acc = _parallel_accuracy[tensor_parallel]() + + def forward(self, *args): + return self.acc(*args) diff --git a/colossalai/nn/metric/_utils.py b/colossalai/nn/metric/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eac591b64c65cd835d4212698b1809c29710b425 --- /dev/null +++ b/colossalai/nn/metric/_utils.py @@ -0,0 +1,7 @@ +import torch + + +def calc_acc(logits, targets): + preds = torch.argmax(logits, dim=-1) + correct = torch.sum(targets == preds) + return correct diff --git a/colossalai/nn/metric/accuracy_2d.py b/colossalai/nn/metric/accuracy_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..a86832973cfda4ffe89f414d9ae7342191293aac --- /dev/null +++ b/colossalai/nn/metric/accuracy_2d.py @@ -0,0 +1,29 @@ +import torch +from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d +from torch import nn + +from ._utils import calc_acc + + +class Accuracy2D(nn.Module): + """Accuracy for 2D parallelism + """ + + def __init__(self): + super().__init__() + + def forward(self, logits, targets): + """Calculate the accuracy of predicted labels. + + Args: + logits (:class:`torch.tensor`): Predicted labels. + targets (:class:`torch.tensor`): True labels from data. + + Returns: + float: the accuracy of prediction. + """ + with torch.no_grad(): + targets = split_batch_2d(targets) + correct = calc_acc(logits, targets) + correct = reduce_by_batch_2d(correct) + return correct diff --git a/colossalai/nn/metric/accuracy_2p5d.py b/colossalai/nn/metric/accuracy_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..3044da065de136b5e2d5f73b2690893f3dc8e240 --- /dev/null +++ b/colossalai/nn/metric/accuracy_2p5d.py @@ -0,0 +1,29 @@ +import torch +from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d +from torch import nn + +from ._utils import calc_acc + + +class Accuracy2p5D(nn.Module): + """Accuracy for 2p5D parallelism + """ + + def __init__(self): + super().__init__() + + def forward(self, logits, targets): + """Calculate the accuracy of predicted labels. + + Args: + logits (:class:`torch.tensor`): Predicted labels. + targets (:class:`torch.tensor`): True labels from data. + + Returns: + float: the accuracy of prediction. + """ + with torch.no_grad(): + targets = split_batch_2p5d(targets) + correct = calc_acc(logits, targets) + correct = reduce_by_batch_2p5d(correct) + return correct diff --git a/colossalai/nn/metric/accuracy_3d.py b/colossalai/nn/metric/accuracy_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..5506fc1d2ffcf918cbb9f079bfe0f07f92b9bc7f --- /dev/null +++ b/colossalai/nn/metric/accuracy_3d.py @@ -0,0 +1,33 @@ +import torch +from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d +from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env +from torch import nn + +from ._utils import calc_acc + + +class Accuracy3D(nn.Module): + """Accuracy for 3D parallelism + """ + def __init__(self): + super().__init__() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + + def forward(self, logits, targets): + """Calculate the accuracy of predicted labels. + + Args: + logits (:class:`torch.tensor`): Predicted labels. + targets (:class:`torch.tensor`): True labels from data. + + Returns: + float: the accuracy of prediction. + """ + with torch.no_grad(): + targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) + targets = split_tensor_3d(targets, 0, self.input_parallel_mode) + correct = calc_acc(logits, targets) + correct = reduce_by_batch_3d(correct, self.input_parallel_mode, self.weight_parallel_mode) + return correct diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..09395d08b93e9d80d3e966772e7197f022f7a851 --- /dev/null +++ b/colossalai/nn/optimizer/README.md @@ -0,0 +1,82 @@ +# Colossal-AI Optimization Techniques + +## Introduction + +Welcome to the large-scale deep learning optimization techniques of [Colossal-AI](https://github.com/hpcaitech/ColossalAI), +which has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc. + + +[Colossal-AI](https://github.com/hpcaitech/ColossalAI), a unified deep learning system for the big model era, integrates +many advanced technologies such as multi-dimensional tensor parallelism, sequence parallelism, heterogeneous memory management, +large-scale optimization, adaptive task scheduling, etc. By using Colossal-AI, we could help users to efficiently and +quickly deploy large AI model training and inference, reducing large AI model training budgets and scaling down the labor cost of learning and deployment. + +### 🚀 Quick Links + +[**Colossal-AI**](https://github.com/hpcaitech/ColossalAI) | +[**Paper**](https://arxiv.org/abs/2110.14883) | +[**Documentation**](https://www.colossalai.org/) | +[**Forum**](https://github.com/hpcaitech/ColossalAI/discussions) | +[**Slack**](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) + + +## Table of Content + +Large transformer models display promising performance on a wide spectrum of AI applications. +Both academia and industry are scaling DL training on larger clusters. However, degrading generalization performance, non-negligible communication overhead, and increasing model size prevent DL researchers and engineers from exploring large-scale AI models. + +We aim to provide a clear sketch of the optimizations for large-scale deep learning with regard to model accuracy and model efficiency. +One way to achieve the goal of maintaining or improving the model accuracy in the large-scale setting while maintaining compute efficiency is to design algorithms that +are less communication and memory hungry. Notably, they are not mutually exclusive but can +be optimized jointly to further speed up training. + +1. Model Accuracy + - Gradient Descent Optimization + - Gradient Descent Variants + - Momentum + - Adaptive Gradient + - Large Batch Training Optimization + - LARS + - LAMB + - Generalization Gap + - Second-Order Optimization + - Hessian-Free + - K-FAC + - Shampoo + +2. Model Accuracy + - Communication Efficiency + - Reduce Volumn of Comm. + - Reduce Frequency of Comm. + - Memory Efficiency + - Mix-Precision Training + - Memory-Efficient Methods, e.g. ZeRO, Gemini, etc. + +Some of the above are still under development. **If you wish to make a contribution to this repository, please read the `Contributing` section below.** + +## Discussion + +Discussion about the Colossal-AI project is always welcomed! We would love to exchange ideas with the community to better help this project grow. +If you think there is a need to discuss anything, you may jump to our [Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w). + +If you encounter any problem while running these optimizers, you may want to raise an issue in this repository. + +## Contributing + +This project welcomes constructive ideas and implementations from the community. + +### Update an Optimizer + +If you find that an optimizer is broken (not working) or not user-friendly, you may put up a pull request to this repository and update this optimizer. + +### Add a New Optimizer + +If you wish to add an optimizer for a specific application, please follow the steps below. + +1. create the new optimizer file in the current folder +2. Prepare the corresponding example files in the [Examples](https://github.com/hpcaitech/ColossalAI-Examples) repository to prove effectiveness of the new optimizer +3. Prepare a detailed readme on environment setup, dataset preparation, code execution, etc. in your example folder +4. Update the table of content (last section above) in this readme file + + +If your PR is accepted, we may invite you to put up a tutorial or blog in [ColossalAI Documentation](https://colossalai.org/). diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..06072648beba974c2e6f638bdc5f3d1162141056 --- /dev/null +++ b/colossalai/nn/optimizer/__init__.py @@ -0,0 +1,10 @@ +from .colossalai_optimizer import ColossalaiOptimizer +from .fused_adam import FusedAdam +from .fused_lamb import FusedLAMB +from .fused_sgd import FusedSGD +from .lamb import Lamb +from .lars import Lars +from .cpu_adam import CPUAdam +from .hybrid_adam import HybridAdam + +__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam'] diff --git a/colossalai/nn/optimizer/colossalai_optimizer.py b/colossalai/nn/optimizer/colossalai_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..34f5a9541975aa3029032d0346947c3524cd5a69 --- /dev/null +++ b/colossalai/nn/optimizer/colossalai_optimizer.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +from torch import Tensor +from torch.optim import Optimizer +from colossalai.utils import clip_grad_norm_fp32 + + +class ColossalaiOptimizer(Optimizer): + + def __init__(self, optim: Optimizer): + self.optim = optim + + @property + def param_groups(self): + return self.optim.param_groups + + @property + def defaults(self): + return self.optim.defaults + + def add_param_group(self, *args, **kwargs): + return self.optim.add_param_group(*args, **kwargs) + + def step(self, *args, **kwargs): + return self.optim.step(*args, **kwargs) + + def zero_grad(self, *args, **kwargs): + self.optim.zero_grad(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + self.optim.load_state_dict(*args, **kwargs) + + def state_dict(self): + return self.optim.state_dict() + + def backward(self, loss: Tensor): + loss.backward() + + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + torch.autograd.backward(tensors=tensor, grad_tensors=grad) + + def clip_grad_norm(self, model: nn.Module, max_norm: float): + if max_norm > 0.0: + clip_grad_norm_fp32(model.parameters(), max_norm) diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..54036973e1e31441f594f1a774693b5f93b01dd3 --- /dev/null +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -0,0 +1,169 @@ +import math +from typing import Optional + +import torch + +from colossalai.kernel.op_builder import CPUAdamBuilder +from colossalai.registry import OPTIMIZERS + +from .nvme_optimizer import NVMeOptimizer + + +@OPTIMIZERS.register_module +class CPUAdam(NVMeOptimizer): + """Implements Adam algorithm. + + Supports parameters updating on both GPU and CPU, depanding on the device of paramters. + But the parameters and gradients should on the same device: + * Parameters on CPU and gradients on CPU is allowed. + * Parameters on GPU and gradients on GPU is allowed. + * Parameters on GPU and gradients on CPU is **not** allowed. + + `CPUAdam` requires CUDA extensions which can be built during installation or runtime. + + This version of CPU Adam accelates parameters updating on CPU with SIMD. + Support of AVX2 or AVX512 is required. + + The GPU part is implemented in an naive way. + + CPU Adam also supports the hybrid precision calculation, eg. fp32 parameters and fp16 gradients. + + :class:`colossalai.nn.optimizer.CPUAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, + or ``torch.optim.Adam`` with ``adamw_mode=False`` + + Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + model_params (iterable): iterable of parameters of dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED yet in CPUAdam! + adamw_mode (boolean, optional): Apply L2 regularization or weight decay + True for decoupled weight decay(also known as AdamW) (default: True) + simd_log (boolean, optional): whether to show if you are using SIMD to + accelerate. (default: False) + nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0. + nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files. + If it's ``None``, a random temporary directory will be used. Defaults to None. + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + # Number of fp32 shards for per parameter + # Param weight, grad, momentum and variance + num_fp32_shards_per_param = 4 + + def __init__(self, + model_params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + adamw_mode=True, + nvme_offload_fraction: float = 0.0, + nvme_offload_dir: Optional[str] = None): + + default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) + super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) + self.adamw_mode = adamw_mode + cpu_adam = CPUAdamBuilder().load() + self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) + + def torch_adam_update(self, + data, + grad, + exp_avg, + exp_avg_sq, + lr, + beta1, + beta2, + eps, + weight_decay, + bias_correction1, + bias_correction2, + use_adamw=False): + # FIXME(ver217): remove the below line when replace torch adam with fused adam + grad = grad.float() + + if weight_decay != 0: + if use_adamw: + data.mul_(1 - lr * weight_decay) + else: + grad = grad.add(data, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # TODO(jiaruifang) dose not support amsgrad + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + step_size = lr / bias_correction1 + + data.addcdiv_(exp_avg, denom, value=-step_size) + + @torch.no_grad() + def step(self, closure=None, div_scale: float = -1): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + self._pre_step('exp_avg', 'exp_avg_sq') + for _, group in enumerate(self.param_groups): + for _, p in enumerate(group['params']): + + if p.grad is None: + continue + + state = self.state[p] + + target_device = p.device + if len(state) == 0: + state['step'] = 0 + + # gradient momentums + state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + # gradient variances + state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + self._post_state_init(p) + + state['step'] += 1 + beta1, beta2 = group['betas'] + + if target_device.type == 'cpu': + assert p.data.numel() == p.grad.data.numel(), "parameter and gradient should have the same size" + assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" + assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" + self._pre_update(p, 'exp_avg', 'exp_avg_sq') + self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], + group['bias_correction'], p.data, p.grad.data, state['exp_avg'], + state['exp_avg_sq'], div_scale) + self._post_update(p, 'exp_avg', 'exp_avg_sq') + elif target_device.type == 'cuda': + assert div_scale == -1, "div_scale should remain default" + assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" + assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" + + bias_correction1 = 1 - beta1**state['step'] + bias_correction2 = 1 - beta2**state['step'] + + # adam on cuda + self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], + beta1, beta2, group['eps'], group['weight_decay'], bias_correction1, + bias_correction2, self.adamw_mode) + else: + raise RuntimeError + self._post_step() + return loss diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..987af8a968b7d883178b145c121edd2a0fc1033a --- /dev/null +++ b/colossalai/nn/optimizer/fused_adam.py @@ -0,0 +1,149 @@ +# modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_adam.py +''' +Copyright 2020 The Microsoft DeepSpeed Team + +Copyright NVIDIA/apex +This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +Licensed under the MIT License. +''' +import torch + +from colossalai.registry import OPTIMIZERS +from colossalai.utils import multi_tensor_applier + + +@OPTIMIZERS.register_module +class FusedAdam(torch.optim.Optimizer): + """Implements Adam algorithm. + + `FusedAdam` requires CUDA extensions which can be built during installation or runtime. + + This version of fused Adam implements 2 fusions. + + * Fusion of the Adam update's elementwise operations + * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. + + :class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, + or ``torch.optim.Adam`` with ``adamw_mode=False`` + + :class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp. + + Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED in FusedAdam! + adamw_mode (boolean, optional): Apply L2 regularization or weight decay + True for decoupled weight decay(also known as AdamW) (default: True) + set_grad_none (bool, optional): whether set grad to None when zero_grad() + method is called. (default: True) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + adamw_mode=True, + weight_decay=0., + amsgrad=False, + set_grad_none=True): + + if amsgrad: + raise RuntimeError('FusedAdam does not support the AMSGrad variant.') + defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay) + super(FusedAdam, self).__init__(params, defaults) + self.adamw_mode = 1 if adamw_mode else 0 + self.set_grad_none = set_grad_none + if multi_tensor_applier.available: + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() + + # Skip buffer + self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + self.multi_tensor_adam = fused_optim.multi_tensor_adam + else: + raise RuntimeError('FusedAdam requires cuda extensions') + + def zero_grad(self, set_to_none=False): + if set_to_none: + for group in self.param_groups: + for p in group['params']: + p.grad = None + else: + super(FusedAdam, self).zero_grad() + + def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, div_scale: float = -1): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes. + """ + if any(p is not None for p in [grads, output_params, scale, grad_norms]): + raise RuntimeError( + 'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.' + ) + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + bias_correction = 1 if group['bias_correction'] else 0 + beta1, beta2 = group['betas'] + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + # create lists for multi-tensor apply + g_l, p_l, m_l, v_l = [], [], [], [] + + for p in group['params']: + if p.grad is None: + continue + if p.grad.data.is_sparse: + raise RuntimeError( + 'FusedAdam does not support sparse gradients, please consider SparseAdam instead') + + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + + if p.dtype not in [torch.float16, torch.float32]: + raise RuntimeError('FusedAdam only support fp16 and fp32.') + + g_l.append(p.grad.data) + p_l.append(p.data) + m_l.append(state['exp_avg']) + v_l.append(state['exp_avg_sq']) + + multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], + beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction, + group['weight_decay'], div_scale) + + return loss diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..72520064e98ba50954189490e7cbdb2f00685340 --- /dev/null +++ b/colossalai/nn/optimizer/fused_lamb.py @@ -0,0 +1,194 @@ +# modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_lamb.py +import torch + +from colossalai.registry import OPTIMIZERS +from colossalai.utils import multi_tensor_applier + + +@OPTIMIZERS.register_module +class FusedLAMB(torch.optim.Optimizer): + """Implements LAMB algorithm. + + `FusedLAMB` requires CUDA extensions which can be built during installation or runtime. + + This version of fused LAMB implements 2 fusions. + + * Fusion of the LAMB update's elementwise operations + * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. + + :class:`colossalai.nn.optimizer.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer + + :class:`colossalai.nn.optimizer.FusedLAMB` may be used with or without Amp. + + LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its norm. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-6) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + NOT SUPPORTED now! (default: False) + adam_w_mode (boolean, optional): Apply L2 regularization or weight decay + True for decoupled weight decay(also known as AdamW) (default: True) + grad_averaging (bool, optional): whether apply (1-beta2) to grad when + calculating running averages of gradient. (default: True) + set_grad_none (bool, optional): whether set grad to None when zero_grad() + method is called. (default: True) + max_grad_norm (float, optional): value used to clip global grad norm + (default: 1.0) + use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0 + weight decay parameter (default: False) + + .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.01, + amsgrad=False, + adam_w_mode=True, + grad_averaging=True, + set_grad_none=True, + max_grad_norm=1.0, + use_nvlamb=False): + if amsgrad: + raise RuntimeError('FusedLAMB does not support the AMSGrad variant.') + defaults = dict(lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + max_grad_norm=max_grad_norm) + super(FusedLAMB, self).__init__(params, defaults) + if multi_tensor_applier.available: + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() + + self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm + # Skip buffer + self._dummy_overflow_buf = torch.tensor([0], + dtype=torch.int, + device=self.param_groups[0]["params"][0].device) + self.multi_tensor_lamb = fused_optim.multi_tensor_lamb + else: + raise RuntimeError('FusedLAMB requires cuda extensions') + + self.adam_w_mode = 1 if adam_w_mode else 0 + self.set_grad_none = set_grad_none + self.use_nvlamb = use_nvlamb + + def zero_grad(self): + if self.set_grad_none: + for group in self.param_groups: + for p in group['params']: + p.grad = None + else: + super(FusedLAMB, self).zero_grad() + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + # create separate grad lists for fp32 and fp16 params + g_all_32, g_all_16 = [], [] + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + if p.dtype == torch.float32: + g_all_32.append(p.grad.data) + elif p.dtype == torch.float16: + g_all_16.append(p.grad.data) + else: + raise RuntimeError('FusedLAMB only support fp16 and fp32.') + + device = self.param_groups[0]["params"][0].device + g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device) + # compute grad norm for two lists + if len(g_all_32) > 0: + g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_32], False)[0] + if len(g_all_16) > 0: + g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_16], False)[0] + + # blend two grad norms to get global grad norm + global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, + [[g_norm_32, g_norm_16]], False)[0] + max_grad_norm = self.defaults['max_grad_norm'] + + for group in self.param_groups: + bias_correction = 1 if group['bias_correction'] else 0 + beta1, beta2 = group['betas'] + grad_averaging = 1 if group['grad_averaging'] else 0 + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + # create lists for multi-tensor apply + g_16, p_16, m_16, v_16 = [], [], [], [] + g_32, p_32, m_32, v_32 = [], [], [], [] + + for p in group['params']: + if p.grad is None: + continue + if p.grad.data.is_sparse: + raise RuntimeError( + 'FusedLAMB does not support sparse gradients, please consider SparseAdam instead') + + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + + if p.dtype == torch.float16: + g_16.append(p.grad.data) + p_16.append(p.data) + m_16.append(state['exp_avg']) + v_16.append(state['exp_avg_sq']) + elif p.dtype == torch.float32: + g_32.append(p.grad.data) + p_32.append(p.data) + m_32.append(state['exp_avg']) + v_32.append(state['exp_avg_sq']) + else: + raise RuntimeError('FusedLAMB only support fp16 and fp32.') + + if (len(g_16) > 0): + multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16], + group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction, + group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm, + max_grad_norm, self.use_nvlamb) + if (len(g_32) > 0): + multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32], + group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction, + group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm, + max_grad_norm, self.use_nvlamb) + + return loss diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..468713b223c15015ce5e53aa84fda6d7f8afb31c --- /dev/null +++ b/colossalai/nn/optimizer/fused_sgd.py @@ -0,0 +1,149 @@ +# modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_sgd.py +import torch +from torch.optim.optimizer import Optimizer, required + +from colossalai.registry import OPTIMIZERS +from colossalai.utils import multi_tensor_applier + + +@OPTIMIZERS.register_module +class FusedSGD(Optimizer): + r"""Implements stochastic gradient descent (optionally with momentum). + + `FusedSGD` requires CUDA extensions which can be built during installation or runtime. + + This version of fused SGD implements 2 fusions. + + * Fusion of the SGD update's elementwise operations + * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. + + :class:`colossalai.nn.optimizer.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD`` + + :class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp. + + Nesterov momentum is based on the formula from + `On the importance of initialization and momentum in deep learning`__. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float): learning rate + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + + __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf + + .. note:: + The implementation of SGD with Momentum/Nesterov subtly differs from + Sutskever et. al. and implementations in some other frameworks. + Considering the specific case of Momentum, the update can be written as + + .. math:: + v = \rho * v + g \\ + p = p - lr * v + + where p, g, v and :math:`\rho` denote the parameters, gradient, + velocity, and momentum respectively. + This is in contrast to Sutskever et. al. and + other frameworks which employ an update of the form + + .. math:: + v = \rho * v + lr * g \\ + p = p - v + + The Nesterov version is analogously modified. + """ + + def __init__(self, + params, + lr=required, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + wd_after_momentum=False): + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super(FusedSGD, self).__init__(params, defaults) + + self.wd_after_momentum = wd_after_momentum + + if multi_tensor_applier.available: + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() + + # Skip buffer + self._dummy_overflow_buf = torch.tensor([0], + dtype=torch.int, + device=self.param_groups[0]["params"][0].device) + self.multi_tensor_sgd = fused_optim.multi_tensor_sgd + else: + raise RuntimeError('FusedSGD requires cuda extensions') + + def __setstate__(self, state): + super(FusedSGD, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + + def get_momentums(self, params): + momentums = [] + first_run = True + for p in params: + param_state = self.state[p] + # torch.optim.SGD initializes momentum in the main loop, we have + # to do it here, and track whether or not we've done so, so that + # momentum application can be skipped in the main kernel. + if 'momentum_buffer' not in param_state: + first_run = True + buf = param_state['momentum_buffer'] = torch.zeros_like(p) + momentums.append(buf) + else: + first_run = False + momentums.append(param_state['momentum_buffer']) + return momentums, first_run + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + # For each group, there are 3 possible combinations we need to consider: + # grad_type, param_to_update_type, momentum_type + # 1. fp16, fp16, fp16 + # 2. fp32, fp32, fp32 + # 3. fp16, fp32, fp32 + g_l, p_l = [], [] + for p in group['params']: + if p.grad is None: + continue + if p.grad.data.is_sparse: + raise RuntimeError('FusedSGD does not support sparse gradients') + g_l.append(p.grad) + p_l.append(p) + m_l, first_run = self.get_momentums(p_l) + multi_tensor_applier(self.multi_tensor_sgd, self._dummy_overflow_buf, [g_l, p_l, m_l], weight_decay, + momentum, dampening, group['lr'], nesterov, first_run, self.wd_after_momentum, 1.0) + + return loss diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..1d0fb92de499ba832af472ae5ad190478bc2d302 --- /dev/null +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -0,0 +1,149 @@ +from typing import Any, Optional + +import torch + +from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder +from colossalai.registry import OPTIMIZERS +from colossalai.utils import multi_tensor_applier + +from .nvme_optimizer import NVMeOptimizer + + +@OPTIMIZERS.register_module +class HybridAdam(NVMeOptimizer): + """Implements Adam algorithm. + + Supports parameters updating on both GPU and CPU, depanding on the device of paramters. + But the parameters and gradients should on the same device: + * Parameters on CPU and gradients on CPU is allowed. + * Parameters on GPU and gradients on GPU is allowed. + * Parameters on GPU and gradients on CPU is **not** allowed. + + `HybriadAdam` requires CUDA extensions which can be built during installation or runtime. + + This version of Hybrid Adam is an hybrid of CPUAdam and FusedAdam. + + * For parameters updating on CPU, it uses CPUAdam. + * For parameters updating on GPU, it uses FusedAdam. + * Hybird precision calculation of fp16 and fp32 is supported, eg fp32 parameters and fp16 gradients. + + :class:`colossalai.nn.optimizer.HybridAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, + or ``torch.optim.Adam`` with ``adamw_mode=False`` + + Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + model_params (iterable): iterable of parameters of dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED yet in CPUAdam! + adamw_mode (boolean, optional): Apply L2 regularization or weight decay + True for decoupled weight decay(also known as AdamW) (default: True) + simd_log (boolean, optional): whether to show if you are using SIMD to + accelerate. (default: False) + nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0. + nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files. + If it's ``None``, a random temporary directory will be used. Defaults to None. + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + # Number of fp32 shards for per parameter + # Param weight, grad, momentum and variance + num_fp32_shards_per_param = 4 + + def __init__(self, + model_params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + adamw_mode=True, + nvme_offload_fraction: float = 0.0, + nvme_offload_dir: Optional[str] = None, + **defaults: Any): + + default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) + super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) + self.adamw_mode = adamw_mode + + # build during runtime if not found + cpu_optim = CPUAdamBuilder().load() + fused_optim = FusedOptimBuilder().load() + self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) + + self.gpu_adam_op = fused_optim.multi_tensor_adam + self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + + @torch.no_grad() + def step(self, closure=None, div_scale: float = -1): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + self._pre_step('exp_avg', 'exp_avg_sq') + for _, group in enumerate(self.param_groups): + g_l, p_l, m_l, v_l = [], [], [], [] + group_step = 0 + for _, p in enumerate(group['params']): + + if p.grad is None: + continue + + state = self.state[p] + + target_device = p.device + if len(state) == 0: + state['step'] = 0 + + # gradient momentums + state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + # gradient variances + state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + self._post_state_init(p) + + state['step'] += 1 + group_step = state['step'] + beta1, beta2 = group['betas'] + + if target_device.type == 'cpu': + assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" + assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" + self._pre_update(p, 'exp_avg', 'exp_avg_sq') + self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], + group['bias_correction'], p.data, p.grad.data, state['exp_avg'], + state['exp_avg_sq'], div_scale) + self._post_update(p, 'exp_avg', 'exp_avg_sq') + + elif target_device.type == 'cuda': + assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" + assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" + + # record the state by gruop and update at once + g_l.append(p.grad.data) + p_l.append(p.data) + m_l.append(state['exp_avg']) + v_l.append(state['exp_avg_sq']) + + else: + raise RuntimeError + if len(g_l) > 0: + adamw_mode = 1 if self.adamw_mode else 0 + bias_correction = 1 if group['bias_correction'] else 0 + multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], + group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode, + bias_correction, group['weight_decay'], div_scale) + self._post_step() + return loss diff --git a/colossalai/nn/optimizer/lamb.py b/colossalai/nn/optimizer/lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac2109572a443e004c58f91bfeb03f85dbbdc33 --- /dev/null +++ b/colossalai/nn/optimizer/lamb.py @@ -0,0 +1,111 @@ +""" +Adapted from the pytorch-lamb library at https://github.com/cybertronai/pytorch-lamb +""" + +import torch +from torch.optim import Optimizer + +from colossalai.registry import OPTIMIZERS + + +@OPTIMIZERS.register_module +class Lamb(Optimizer): + r"""Implements Lamb algorithm. + It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-6) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + adam (bool, optional): always use trust ratio = 1, which turns this into + Adam. Useful for comparison purposes. + + .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0, adam=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.adam = adam + super(Lamb, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # Decay the first and second moment running average coefficient + # m_t + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + # v_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Paper v3 does not use debiasing. + # bias_correction1 = 1 - beta1 ** state['step'] + # bias_correction2 = 1 - beta2 ** state['step'] + # Apply bias to lr to avoid broadcast. + # * math.sqrt(bias_correction2) / bias_correction1 + step_size = group['lr'] + + weight_norm = p.data.pow(2).sum().sqrt() + + adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) + if group['weight_decay'] != 0: + adam_step.add_(p.data, alpha=group['weight_decay']) + + adam_norm = adam_step.pow(2).sum().sqrt() + if weight_norm == 0 or adam_norm == 0: + trust_ratio = 1 + else: + trust_ratio = weight_norm / adam_norm + state['weight_norm'] = weight_norm + state['adam_norm'] = adam_norm + state['trust_ratio'] = trust_ratio + if self.adam: + trust_ratio = 1 + + p.data.add_(adam_step, alpha=-step_size * trust_ratio) + + return loss diff --git a/colossalai/nn/optimizer/lars.py b/colossalai/nn/optimizer/lars.py new file mode 100644 index 0000000000000000000000000000000000000000..212f66671a0db9f580d99758823fdf78e3e54106 --- /dev/null +++ b/colossalai/nn/optimizer/lars.py @@ -0,0 +1,102 @@ +"""Adapted from https://github.com/NUS-HPC-AI-Lab/LARS-ImageNet-PyTorch/blob/main/lars.py""" + +from typing import Iterable + +import torch +from torch.optim import Optimizer + +from colossalai.registry import OPTIMIZERS + + +@OPTIMIZERS.register_module +class Lars(Optimizer): + r"""Implements the LARS optimizer from `"Large batch training of convolutional networks" + `_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + momentum (float, optional): momentum factor (default: 0) + eeta (float, optional): LARS coefficient as used in the paper (default: 1e-3) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + """ + + def __init__( + self, + params: Iterable[torch.nn.Parameter], + lr=1e-3, + momentum=0, + eeta=1e-3, + weight_decay=0, + epsilon=0.0 + ) -> None: + if not isinstance(lr, float) or lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay)) + if eeta <= 0 or eeta > 1: + raise ValueError("Invalid eeta value: {}".format(eeta)) + if epsilon < 0: + raise ValueError("Invalid epsilon value: {}".format(epsilon)) + defaults = dict(lr=lr, momentum=momentum, + weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True) + + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + eeta = group['eeta'] + lr = group['lr'] + lars = group['lars'] + eps = group['epsilon'] + + for p in group['params']: + if p.grad is None: + continue + decayed_grad = p.grad + scaled_lr = lr + if lars: + w_norm = torch.norm(p) + g_norm = torch.norm(p.grad) + trust_ratio = torch.where( + w_norm > 0 and g_norm > 0, + eeta * w_norm / (g_norm + weight_decay * w_norm + eps), + torch.ones_like(w_norm) + ) + trust_ratio.clamp_(0.0, 50) + scaled_lr *= trust_ratio.item() + if weight_decay != 0: + decayed_grad = decayed_grad.add(p, alpha=weight_decay) + decayed_grad = torch.clamp(decayed_grad, -10.0, 10.0) + + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = torch.clone( + decayed_grad).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(decayed_grad) + decayed_grad = buf + + p.add_(decayed_grad, alpha=-scaled_lr) + + return loss diff --git a/colossalai/nn/optimizer/nvme_optimizer.py b/colossalai/nn/optimizer/nvme_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..53e4a46c9741dc8cb5fe3fc2335789beada3e3c8 --- /dev/null +++ b/colossalai/nn/optimizer/nvme_optimizer.py @@ -0,0 +1,165 @@ +import math +import os +import tempfile +from typing import Callable, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + + +class NVMeOptimizer(torch.optim.Optimizer): + """A base class for offloading optimizer states. + + Args: + params: parameters + defaults (dict): default dict + nvme_offload_fraction (float, optional): Fraction of params to be offloaded to NVMe. Defaults to 0.0. + offload_dir (Optional[str], optional): Directory to save NVMe offload files. + If it's ``None``, a random temporary directory will be used. Defaults to None. + + Raises: + ImportError: Raise if ``tensornvme`` is not installed. + """ + + def __init__(self, + params, + defaults: dict, + nvme_offload_fraction: float = 0.0, + offload_dir: Optional[str] = None) -> None: + assert 0.0 <= nvme_offload_fraction <= 1.0 + super().__init__(params, defaults) + self.nvme_offload_fraction = float(nvme_offload_fraction) + if self.nvme_offload_fraction > 0.0: + try: + from tensornvme import DiskOffloader + from tensornvme._C import get_backends + except ModuleNotFoundError: + raise ModuleNotFoundError('Please install tensornvme to use NVMeOptimizer') + self.offload_dir = offload_dir or tempfile.mkdtemp() + backend = 'uring' if 'uring' in get_backends() else 'aio' + self.offloader = DiskOffloader(self.offload_dir, 8, backend=backend) + else: + self.offload_dir = None + self.offloader = None + self.is_on_nvme: Dict[Parameter, bool] = {} + self.offloaded_numel: int = 0 + # As param may be not materialized here, these attributes are initalized when the first step + self.total_numel: Optional[int] = None + self.can_offload_numel: Optional[int] = None + + self.prefetch_params: List[Parameter] = [] + self.param_to_prefetch_idx: Dict[Parameter, int] = {} + + def _get_numel(self) -> int: + numel = 0 + for group in self.param_groups: + for p in group['params']: + numel += p.storage().size() + return numel + + def _post_state_init(self, param: Parameter) -> None: + numel = param.storage().size() + if self.offloader is not None and param.device.type == 'cpu' and numel + self.offloaded_numel <= self.can_offload_numel: + self.is_on_nvme[param] = True + self.offloaded_numel += numel + else: + self.is_on_nvme[param] = False + + def _setup_prefetch_params(self) -> List[Parameter]: + if self.offloader is None: + return + assert len(self.prefetch_params) == 0 and len(self.param_to_prefetch_idx) == 0 + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + if len(self.state[p]) > 0 and self.is_on_nvme[p]: + assert p.device.type == 'cpu' + self.param_to_prefetch_idx[p] = len(self.prefetch_params) + self.prefetch_params.append(p) + + def _pre_step(self, *state_keys: str) -> None: + if self.total_numel is None: + self.total_numel = self._get_numel() + self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction) + self._setup_prefetch_params() + if self.offloader is None or len(self.prefetch_params) == 0: + return + state = self.state[self.prefetch_params[0]] + for key in state_keys: + self.offloader.async_read(state[key]) + + def _pre_update(self, param: Parameter, *state_keys: str) -> None: + if self.offloader is None or param not in self.param_to_prefetch_idx: + return + self.offloader.sync_read_events() + idx = self.param_to_prefetch_idx[param] + if idx + 1 < len(self.prefetch_params): + state = self.state[self.prefetch_params[idx + 1]] + for key in state_keys: + self.offloader.async_read(state[key]) + + def _post_update(self, param: Parameter, *state_keys: str) -> None: + if self.offloader is None: + return + self.offloader.sync_write_events() + if self.is_on_nvme[param]: + state = self.state[param] + for key in state_keys: + self.offloader.async_write(state[key]) + + def _post_step(self) -> None: + if self.offloader is not None: + self.offloader.synchronize() + self.prefetch_params.clear() + self.param_to_prefetch_idx.clear() + + def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]: + """Performs a single optimization step (parameter update). + + Example: + + >>> self._pre_step('exp_avg', 'exp_avg_sq') + >>> for group in self.param_groups: + >>> for p in group['params']: + >>> if p.grad is None: + >>> continue + >>> state = self.state[p] + >>> if len(state) == 0: + >>> state['exp_avg'] = ... + >>> state['exp_avg_sq'] = ... + >>> self._post_state_init(p) + >>> if p.device.type == 'cpu': + >>> self._pre_update(p, 'exp_avg', 'exp_avg_sq') + >>> adam() + >>> self._post_update(p, 'exp_avg', 'exp_avg_sq') + >>> else: + >>> ... + >>> self._post_step() + + Args: + closure (Optional[Callable[[], float]], optional): A closure that reevaluates the model and + returns the loss. Optional for most optimizers. + """ + raise NotImplementedError + + def state_dict(self) -> dict: + # TODO(ver217): design a new method to save state_dict. When using NVMe offload, this method may lead to OOM. + if self.offloader is not None: + raise NotImplementedError + return super().state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + # TODO(ver217): design a new method to load state_dict. When using NVMe offload, whole state_dict may not be able to fit in memory. + if self.offloader is not None: + raise NotImplementedError + super().load_state_dict(state_dict) + + def __del__(self) -> None: + if getattr(self, 'offloader', None) is not None: + del self.offloader + if os.path.exists(self.offload_dir): + try: + os.rmdir(self.offload_dir) + except OSError: + pass diff --git a/colossalai/nn/parallel/__init__.py b/colossalai/nn/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..17e010f478c92f5c445c32ef2c512ef3007f95b2 --- /dev/null +++ b/colossalai/nn/parallel/__init__.py @@ -0,0 +1,5 @@ +from .data_parallel import ColoDDP + +__all__ = [ + 'ColoDDP', +] diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..f839d6b2844491a870b38a68459c4bedaf9d760a --- /dev/null +++ b/colossalai/nn/parallel/data_parallel.py @@ -0,0 +1,174 @@ +from collections import OrderedDict +from functools import partial +from typing import Iterable, Optional, Set + +import torch +import torch.distributed as dist + +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.utils import is_ddp_ignored + +from .reducer import Reducer + + +def free_storage(data: torch.Tensor) -> None: + """Free underlying storage of a Tensor.""" + if data.storage().size() > 0: + # Since we're modifying the Tensor's Storage directly, make sure the Tensor + # is the sole occupant of the Storage. + assert data.storage_offset() == 0 + data.storage().resize_(0) + + +def _cast_float(args, dtype: torch.dtype): + if isinstance(args, torch.Tensor) and torch.is_floating_point(args): + args = args.to(dtype) + elif isinstance(args, (list, tuple)): + args = type(args)(_cast_float(t, dtype) for t in args) + elif isinstance(args, dict): + args = {k: _cast_float(v, dtype) for k, v in args.items()} + return args + + +class ColoDDP(torch.nn.Module): + """Distributed data parallel for ColoTensor. Nested ColoDDP is not supported now. + + Example: + >>> from colossalai.core import global_context as gpc + >>> from colossalai.context import ParallelMode + >>> model = torch.nn.Linear(20, 1) + >>> pg = ProcessGroup(tp_degree = world_size//2) + >>> model = ColoDDP(model, pg) + >>> logits = model(x) + >>> loss = criterion(logits, labels) + >>> model.backward(loss) + + Args: + module (torch.nn.Module): Module to apply DDP. + process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses. + If it's None, the default data parallel group will be used. Defaults to None. + """ + + def __init__(self, + module: torch.nn.Module, + process_group: ColoProcessGroup, + bucket_cap_mb: int = 25, + rebuild_bucket: bool = True) -> None: + assert not isinstance(module, ColoDDP) + super().__init__() + self.module = module + self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() + assert process_group + + self.process_group = process_group + self.dp_world_size = self.process_group.dp_world_size() + + self.reducer = Reducer(bucket_cap_mb) + self.rebuild_bucket = rebuild_bucket + for p in module.parameters(): + if is_ddp_ignored(p): + continue + if p.requires_grad: + p.register_hook(partial(self.grad_handle, p)) + + def parameters(self, recurse: bool = True): + return self.module.parameters(recurse) + + def named_parameters(self, prefix: str = '', recurse: bool = True): + return self.module.named_parameters(prefix, recurse) + + def named_buffers(self, prefix: str = '', recurse: bool = True): + return self.module.named_buffers(prefix, recurse) + + def named_children(self): + return self.module.named_children() + + def named_modules(self, + memo: Optional[Set[torch.nn.Module]] = None, + prefix: str = '', + remove_duplicate: bool = True): + return self.module.named_modules(memo, prefix, remove_duplicate) + + def forward(self, *args, **kwargs): + self.module.zero_grad(set_to_none=True) + return self.module(*args, **kwargs) + + def backward(self, loss: torch.Tensor): + loss.backward() + with torch.cuda.stream(self.comm_stream): + self.reducer.flush() + torch.cuda.current_stream().wait_stream(self.comm_stream) + if self.rebuild_bucket: + self.reducer.free() + for p in self.module.parameters(): + if is_ddp_ignored(p): + continue + if p.grad.device.type != "cpu": + p.grad = p._saved_grad + + def grad_handle(self, p, grad): + if grad.device.type != "cpu": + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + if self.dp_world_size > 1: + grad = grad / self.dp_world_size + self.comm_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.comm_stream): + self.reducer.all_reduce_async(grad, + group=self.process_group.dp_process_group(), + callback_fn=partial(self._save_grad, p)) + grad.record_stream(self.comm_stream) + else: + ColoDDP._save_grad(p, grad) + return empty_grad + + else: + # TODO(jiaruifang) fixme + self.process_group.set_cpu_groups() + dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group()) + return grad + + @staticmethod + def _save_grad(p, grad): + if hasattr(p, '_saved_grad'): + p._saved_grad.add_(grad) + else: + p._saved_grad = grad + + def zero_grad(self, set_to_none: bool = False) -> None: + self.module.zero_grad(set_to_none=True) + for p in self.module.parameters(): + if getattr(p, '_saved_grad', None) is not None: + if set_to_none: + p._saved_grad = None + else: + if p._saved_grad.grad_fn is not None: + p._saved_grad.detach_() + else: + p._saved_grad.requires_grad_(False) + p._saved_grad.zero_() + + @staticmethod + def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None: + """Sets parameters to be ignored by DDP. + This method must be called before initializing ColoDDP. + + Example: + >>> params_to_ignore = [] + >>> for p in module.parameters(): + >>> if should_ignore(p): + >>> params_to_ignore.append(p) + >>> ColoDDP.set_params_to_ignore(params_to_ignore) + >>> module = ColoDDP(module) + + Args: + params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored. + """ + for p in params_to_ignore: + p._ddp_to_ignore = True + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + + def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): + return self.module.load_state_dict(state_dict, strict) diff --git a/colossalai/nn/parallel/layers/__init__.py b/colossalai/nn/parallel/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..29b8353e63c5930950d28d58b0c122c76486c3e1 --- /dev/null +++ b/colossalai/nn/parallel/layers/__init__.py @@ -0,0 +1,14 @@ +from .colo_module import ColoModule +from .linear import ColoLinear +from .embedding import ColoEmbedding +from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module + +from .cache_embedding import CachedEmbeddingBag, ParallelCachedEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \ + ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelCachedEmbeddingBagTablewiseSpiltCache + +__all__ = [ + 'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', + 'ColoLinear', 'ColoEmbedding', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'CachedParamMgr', + 'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', + 'ParallelCachedEmbeddingBagTablewiseSpiltCache' +] diff --git a/colossalai/nn/parallel/layers/cache_embedding/__init__.py b/colossalai/nn/parallel/layers/cache_embedding/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5bbc931a79dceeffcf827fc3fd3b18ca8bf87dd3 --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/__init__.py @@ -0,0 +1,13 @@ +from .cache_mgr import CachedParamMgr, EvictionStrategy +from .copyer import LimitBuffIndexCopyer +from .cached_embedding import CachedEmbeddingBag +from .parallel_cached_embedding import ParallelCachedEmbeddingBag +from .embedding_config import TablewiseEmbeddingBagConfig +from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTablewise +from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache + +__all__ = [ + 'CachedParamMgr', 'LimitBuffIndexCopyer', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'EvictionStrategy', + 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', + 'ParallelCachedEmbeddingBagTablewiseSpiltCache' +] diff --git a/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..705835a0ed22ef5c9f6ecc388ddcf8d4e2ea3073 --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py @@ -0,0 +1,36 @@ +import abc +import torch.nn as nn + + +class BaseEmbeddingBag(abc.ABC, nn.Module): + + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2., + scale_grad_by_freq=False, + sparse=False, + mode='mean', + include_last_offset=False, + ): + super(BaseEmbeddingBag, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + + # Specific to embedding bag + self.mode = mode + self.include_last_offset = include_last_offset diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py new file mode 100644 index 0000000000000000000000000000000000000000..da043df368ae17ed1b95cc42b90e58994c84f9aa --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -0,0 +1,583 @@ +import numpy as np +import torch +from torch.profiler import record_function +from typing import List, Optional +from contexttimer import Timer +from .copyer import LimitBuffIndexCopyer +from enum import Enum +import sys +from contextlib import contextmanager + + +class EvictionStrategy(Enum): + LFU = 1 + # dataset aware eviction strategy + DATASET = 2 + + +def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None: + if stream is None: + return + torch.cuda.current_stream().wait_stream(stream) + # As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html, + # PyTorch uses the "caching allocator" for memroy allocation for tensors. When a tensor is + # freed, its memory is likely to be reused by newly constructed tenosrs. By default, + # this allocator traces whether a tensor is still in use by only the CUDA stream where it + # was created. When a tensor is used by additional CUDA streams, we need to call record_stream + # to tell the allocator about all these streams. Otherwise, the allocator might free the + # underlying memory of the tensor once it is no longer used by the creator stream. This is + # a notable programming trick when we write programs using multi CUDA streams. + cur_stream = torch.cuda.current_stream() + assert isinstance(t, torch.Tensor) + t.record_stream(cur_stream) + + +class CachedParamMgr(torch.nn.Module): + """ + Manage Embedding Weights on CPU and CUDA memory uses a software cache. + CPU maintains the entire original weight. + CUDA maintains a fraction of the weights used in the upcoming computation. The row number in CUDA is controlled by `cuda_row_num`. + During training, GPU needs to transmit embedding rows between CPU and GPU. + Args: + weight (torch.Tensor): the weight of the Embedding layer. + cuda_row_num (int, optional): the number of rows cached in CUDA memory. Defaults to 0. + buffer_size (int, optional): the number of rows in a data transmitter buffer. Defaults to 50_000. + pin_weight (bool, optional): use pin memory to store the cpu weight. If set `True`, the cpu memory usage will increase largely. Defaults to False. + evict_strategy (EvictionStrategy, optional): the eviction strategy. There are two options. + `EvictionStrategy.LFU`: use the least frequently used cache. + `EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume. + Defaults to EvictionStrategy.DATASET. + """ + + def __init__( + self, + weight: torch.Tensor, + cuda_row_num: int = 0, + buffer_size: int = 0, + pin_weight: bool = True, + evict_strategy: EvictionStrategy = EvictionStrategy.DATASET, + async_copy: bool = False, + ) -> None: + super(CachedParamMgr, self).__init__() + self.buffer_size = buffer_size + self.num_embeddings, self.embedding_dim = weight.shape + self.cuda_row_num = cuda_row_num + self._cuda_available_row_num = self.cuda_row_num + self.pin_weight = pin_weight + self.elem_size_in_byte = weight.element_size() + + # weight configure + self._init_weight(weight) + + # Perf log + self.num_hits_history = [] + self.num_miss_history = [] + self.num_write_back_history = [] + + self._evict_strategy = evict_strategy + + self._async_copy = async_copy + + if self._async_copy: + self._memcpy_stream = torch.cuda.Stream() + + print('use async copy') + + if self._evict_strategy == EvictionStrategy.LFU: + # cache_row_idx -> frequency, freq of the cache rows. + # classic lfu cache. evict the minimal freq value row in cuda cache. + self.register_buffer("freq_cnter", + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), + dtype=torch.long).fill_(sys.maxsize), + persistent=False) + self._elapsed_dict = {} + self._show_cache_miss = True + self._reset_comm_stats() + + def _reset_comm_stats(self): + for k in self._elapsed_dict.keys(): + self._elapsed_dict[k] = 0 + + self._cpu_to_cuda_numel = 0 + self._cuda_to_cpu_numel = 0 + if self._show_cache_miss: + self._cache_miss = 0 + self._total_cache = 0 + + @contextmanager + def timer(self, name): + with Timer() as t: + yield + torch.cuda.synchronize() + + if name not in self._elapsed_dict.keys(): + self._elapsed_dict[name] = 0 + self._elapsed_dict[name] += t.elapsed + + def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor: + """_find_evict_gpu_idxs + Find the gpu idxs to be evicted, according to their freq. + Args: + evict_num (int): how many rows has to be evicted + Returns: + torch.Tensor: a list tensor (1D), contains the gpu_row_idxs. + """ + if self._evict_strategy == EvictionStrategy.LFU: + # find the minimal evict_num freq entries in cached_idx_map + _, evict_gpu_row_idxs = torch.topk(self.freq_cnter, evict_num, largest=False) + return evict_gpu_row_idxs + elif self._evict_strategy == EvictionStrategy.DATASET: + # cached_idx_map itself implies the priority of eviction. + # The value of self.cached_idx_map represents cpu_row_idx. + # The larger it is, the less frequently it will appear in the dataset, + # and the higher its eviction priority will be. + _, evict_gpu_row_idxs = torch.topk(self.cached_idx_map, evict_num, largest=True) + return evict_gpu_row_idxs + else: + raise TypeError + + def _init_weight(self, weight): + if self.cuda_row_num > 0: + # Enable cache with introducing auxiliary data structures + self.cuda_cached_weight = torch.nn.Parameter( + torch.zeros(self.cuda_row_num, + self.embedding_dim, + device=torch.cuda.current_device(), + dtype=weight.dtype)) + + # pin memory cpu for higher CPU-GPU copy bandwidth + self.weight = weight.pin_memory() if self.pin_weight else weight + # map original id to new id with respect to frequency + # id -> cpu_row_idx + self.register_buffer( + "idx_map", + torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()), + persistent=False, + ) + + # cached_idx_map: gpu_row_idx -> cpu_row_idx + self.register_buffer("cached_idx_map", + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), + dtype=torch.long).fill_(-1), + persistent=False) + + # cpu_row_id -> gpu_row_idx. + # gpu_row_idx as -1 means cpu_row_id not in CUDA. + self.register_buffer("inverted_cached_idx", + torch.zeros(self.num_embeddings, device=torch.cuda.current_device(), + dtype=torch.long).fill_(-1), + persistent=False) + + self.evict_backlist = torch.tensor([], device=torch.cuda.current_device()) + + # index copy buffer size should less than 10% of cuda weight. + if self.buffer_size > 0: + self.limit_buff_index_copyer = LimitBuffIndexCopyer(self.buffer_size) + + else: + # Disable cache so that FreqCacheEmbedding is compatible with vanilla EmbeddingBag + # self.weight = torch.nn.Parameter(weight) + # self.cuda_cached_weight = self.weight + raise NotImplementedError() + + def cpu_weight_data(self, row_idx: int) -> torch.Tensor: + """ + access a row of CPU weight. + Args: + row_idx (int): the idx of rows + Returns: + torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D. + """ + + return self.weight.data.view(-1).narrow(0, + int(row_idx) * self.embedding_dim, + self.embedding_dim).view(1, self.embedding_dim) + + @property + def cuda_available_row_num(self): + return self._cuda_available_row_num + + @torch.no_grad() + def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7): + """reorder + reorder the weight according to ids' frequency in dataset before training. + Execute only once before training, also known as warmup phase. + + Note: + If you would like to use the DATASET as the eviction strategy, you must call this function. + Note: + If you are use the LFU as the eviction strategy, you can skip this function. If you still use this function. It will initialize + The frequency in LFU cache using the dataset statistics. + Args: + ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight. + warmup_ratio (float): the amount of chunks preloaded in cuda cache + """ + # reorder phase: reorder the cpu weight according to their freq stats in the target dataset. + # reorder only works for DATASET eviction strategy. + + if ids_freq_mapping is not None and not isinstance(ids_freq_mapping, torch.Tensor): + ids_freq_mapping = torch.tensor(ids_freq_mapping) + + if self._evict_strategy == EvictionStrategy.DATASET: + if ids_freq_mapping is not None: + tmp_idx = torch.argsort(ids_freq_mapping, descending=True) + sorted_idx = torch.argsort(tmp_idx) + self.idx_map.data.copy_(sorted_idx) + + # warmup phase: copy #preload_row_num rows from cpu to gpu. + preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings) + if preload_row_num > 0: + with Timer() as timer: + # extract rows from cpu weight + if self._evict_strategy == EvictionStrategy.LFU and ids_freq_mapping is not None: + freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True) + preload_cuda_row_idxs = torch.arange(preload_row_num).cuda() + else: + preload_cpu_ids = torch.arange(preload_row_num) + preload_cuda_row_idxs = preload_cpu_ids.cuda() + if self.buffer_size > 0: + self.limit_buff_index_copyer.index_copy(0, + src_index=preload_cpu_ids, + tgt_index=preload_cuda_row_idxs, + src=self.weight.view(self.num_embeddings, -1), + tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) + else: + preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda() + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs, + preload_rows) + + # update auxiliary info + self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda() + self.inverted_cached_idx[preload_cpu_ids] = preload_cuda_row_idxs + self._cuda_available_row_num -= preload_row_num + + if self._evict_strategy == EvictionStrategy.LFU: + # if the ids_freq_mapping is not None, we initialize the embedding row's freq value in LFU as its freq in dataset. + if ids_freq_mapping is None: + self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0) + else: + self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda() + + print(f'Cache warmup finished cost {timer.elapsed} sec.') + + def flush(self): + """flush all CUDA rows to CPU. + The function is usually called after training finished. + """ + slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1) + row_ids = self.cached_idx_map[slots] + rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu() + self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows) + self.cached_idx_map.index_fill_(0, slots, -1) + self.inverted_cached_idx.index_fill_(0, row_ids, -1) + self._cuda_available_row_num += slots.numel() + + if self._show_cache_miss: + self._cache_miss = 0 + self._total_cache = 0 + + if self._evict_strategy == EvictionStrategy.LFU: + self.freq_cnter.fill_(sys.maxsize) + assert self._cuda_available_row_num == self.cuda_row_num + assert torch.all(self.inverted_cached_idx == -1).item() + assert torch.all(self.cached_idx_map == -1).item() + + def print_comm_stats(self): + if self._cuda_to_cpu_numel > 0 and "3_evict_out" in self._elapsed_dict: + elapsed = self._elapsed_dict["3_evict_out"] + print( + f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem" + ) + print(f'cuda_to_cpu_elapse {elapsed} sec') + if self._cpu_to_cuda_numel > 0 and "5_evict_in" in self._elapsed_dict: + elapsed = self._elapsed_dict["5_evict_in"] + print( + f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem" + ) + print(f'cpu_to_cuda_elpase {elapsed} sec') + + for k, v in self._elapsed_dict.items(): + print(f'{k}: {v}') + + print(f'cache miss ratio {self._cache_miss / self._total_cache}') + + @torch.no_grad() + def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor: + """ + convert ids to indices in self.cuda_cached_weight. + Implemented with parallel operations on GPU. + Args: + ids (torch.Tensor): ids from the dataset + Returns: + torch.Tensor: contains indices in self.cuda_cached_weight + """ + ids = self.idx_map.index_select(0, ids.view(-1)) + ret = self.inverted_cached_idx.index_select(0, ids) + return ret + + @torch.no_grad() + def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor: + """ + move the cpu embedding rows w.r.t. ids into CUDA memory + Args: + ids (torch.Tensor): the ids to be computed + Returns: + torch.Tensor: indices on the cuda_cached_weight. + """ + torch.cuda.synchronize() + with self.timer("cache_op") as gtimer: + # identify cpu rows to cache + with self.timer("1_identify_cpu_row_idxs") as timer: + with record_function("(cache) get unique indices"): + if self._evict_strategy == EvictionStrategy.LFU: + cpu_row_idxs, repeat_times = torch.unique(ids, return_counts=True) + else: + cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True) + + assert len(cpu_row_idxs) <= self.cuda_row_num, \ + f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \ + f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \ + f"Please increase cuda_row_num or decrease the training batch size." + self.evict_backlist = cpu_row_idxs + tmp = torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True) + comm_cpu_row_idxs = cpu_row_idxs[tmp] + + if self._show_cache_miss: + self._cache_miss += torch.sum(repeat_times[tmp]) + self._total_cache += ids.numel() + + self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs)) + self.num_miss_history.append(len(comm_cpu_row_idxs)) + self.num_write_back_history.append(0) + + # move sure the cuda rows will not be evicted! + with record_function("(cache) prepare_rows_on_cuda"): + with self.timer("prepare_rows_on_cuda") as timer: + self._prepare_rows_on_cuda(comm_cpu_row_idxs) + + self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype) + + with self.timer("6_update_cache") as timer: + with record_function("6_update_cache"): + gpu_row_idxs = self._id_to_cached_cuda_id(ids) + + # update for LFU. + if self._evict_strategy == EvictionStrategy.LFU: + unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs] + self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times) + + return gpu_row_idxs + + def _row_in_cuda(self, row_id: int) -> bool: + return self.inverted_cached_idx[row_id] != -1 + + @torch.no_grad() + def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: + """prepare rows in cpu_row_idxs on CUDA memory + Args: + cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA + """ + evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num + + cpu_row_idxs_copy = cpu_row_idxs.cpu() + + # move evict in rows to gpu + if self._async_copy: + if self.buffer_size == 0: + evict_in_rows_gpu = self.weight.view(self.num_embeddings, + -1).index_select(0, cpu_row_idxs_copy).pin_memory() + with torch.cuda.stream(self._memcpy_stream): + evict_in_rows_gpu = evict_in_rows_gpu.to(torch.cuda.current_device(), non_blocking=True) + else: + raise NotImplemented + + if evict_num > 0: + with self.timer("2_identify_cuda_row_idxs") as timer: + mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist) + invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) + if self._evict_strategy == EvictionStrategy.DATASET: + # mask method. + # set cached_idx_map[invalid_idxs] to -2. + # so those idxs will be sorted to end, therefore not being chosen as victim + backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone() + self.cached_idx_map.index_fill_(0, invalid_idxs, -2) + + with self.timer("2_1_find_evict_gpu_idxs") as timer: + evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) + + # move evict out rows to cpu + if self._async_copy: + evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, + -1).index_select(0, evict_gpu_row_idxs) + evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) + with torch.cuda.stream(None): + evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) + self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) + + elif self._evict_strategy == EvictionStrategy.LFU: + with self.timer("2_1_backup_freqs") as timer: + backup_freqs = self.freq_cnter[invalid_idxs].clone() + self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize) + + with self.timer("2_2_find_evict_gpu_idxs") as timer: + evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) + + if self._async_copy: + evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, + -1).index_select(0, evict_gpu_row_idxs) + evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) + with torch.cuda.stream(None): + evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) + + with self.timer("2_3_revert_freqs") as timer: + self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs) + + evict_info = self.cached_idx_map[evict_gpu_row_idxs] + + with self.timer("3_evict_out") as timer: + if self.buffer_size > 0: + self.limit_buff_index_copyer.index_copy(0, + src_index=evict_gpu_row_idxs, + tgt_index=evict_info.cpu(), + src=self.cuda_cached_weight.view(self.cuda_row_num, -1), + tgt=self.weight.view(self.num_embeddings, -1)) + else: + # allocate tmp memory on CPU and copy rows on CUDA to CPU. + # TODO async gpu -> cpu + if self._async_copy: + _wait_for_data(evict_out_rows_cpu, None) + else: + with self.timer("3_1_evict_out_index_select") as timer: + evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, + -1).index_select(0, evict_gpu_row_idxs) + with self.timer("3_2_evict_out_gpu_to_cpu_copy") as timer: + evict_out_rows_cpu = evict_out_rows_cpu.cpu() + + with self.timer("3_2_evict_out_cpu_copy") as timer: + self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu) + + self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1) + self.inverted_cached_idx.index_fill_(0, evict_info, -1) + # self.freq_cnter.index_fill(0, evict_gpu_row_idxs, sys.maxsize) # unnecessary + self._cuda_available_row_num += evict_num + + weight_size = evict_gpu_row_idxs.numel() * self.embedding_dim + self._cuda_to_cpu_numel += weight_size + # print(f"evict embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB") + + # slots of cuda weight to evict in + with self.timer("4_identify_cuda_slot") as timer: + slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()] + + # TODO wait for optimize + with self.timer("5_evict_in") as timer: + # Here also allocate extra memory on CUDA. #cpu_row_idxs + if self.buffer_size > 0: + self.limit_buff_index_copyer.index_copy(0, + src_index=cpu_row_idxs_copy, + tgt_index=slots, + src=self.weight.view(self.num_embeddings, -1), + tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) + else: + if self._async_copy: + _wait_for_data(evict_in_rows_gpu, self._memcpy_stream) + else: + with self.timer("5_1_evict_in_index_select") as timer: + # narrow index select to a subset of self.weight + # tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1) + # evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu()) + evict_in_rows_gpu = self.weight.view(self.num_embeddings, + -1).index_select(0, cpu_row_idxs_copy).pin_memory() + + with self.timer("5_2_evict_in_gpu_to_cpu_copy") as timer: + evict_in_rows_gpu = evict_in_rows_gpu.cuda() + + with self.timer("5_3_evict_in_index_copy") as timer: + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu) + + with self.timer("6_update_cache") as timer: + self.cached_idx_map[slots] = cpu_row_idxs + self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slots) + if self._evict_strategy == EvictionStrategy.LFU: + self.freq_cnter.index_fill_(0, slots, 0) + self._cuda_available_row_num -= cpu_row_idxs.numel() + + weight_size = cpu_row_idxs.numel() * self.embedding_dim + self._cpu_to_cuda_numel += weight_size + # print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB") + + def _find_free_cuda_row(self) -> int: + if self._cuda_available_row_num == 0: + return -1 + candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1) + return candidates[0].item() + + def _evict(self) -> int: + """ + deprecated + evict one row from cuda to cpu. + Returns: + (int) : the slot id be evicted. + """ + mask = torch.logical_or(torch.isin(self.cached_idx_map, self.evict_backlist), self.cached_idx_map == -1) + buf = self.cached_idx_map[mask].clone() + idx = torch.nonzero(mask).squeeze(1) + self.cached_idx_map.index_fill_(0, idx, -1) + max_row, max_cpu_row_idx = torch.max(self.cached_idx_map, dim=0) + max_gpu_row_idx = self.cached_idx_map[max_cpu_row_idx] + + if max_gpu_row_idx == -1: + raise RuntimeError("Can not evict a row") + + max_gpu_row_idx = max_gpu_row_idx.item() + max_offset = self.inverted_cached_idx[max_gpu_row_idx] + # recover + self.cached_idx_map.index_copy_(0, idx, buf) + + with Timer() as timer: + cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim, + self.embedding_dim).view(1, self.embedding_dim) + self.cpu_weight_data(max_gpu_row_idx).data.copy_(cuda_tensor) + + # update inverted_cached_idx, min_slot_id is evicted from cuda + self.cached_idx_map[max_cpu_row_idx] = -1 + if self._evict_strategy == EvictionStrategy.LFU: + self.freq_cnter[max_cpu_row_idx] = sys.maxsize + self.inverted_cached_idx[max_gpu_row_idx] = -1 + + self._cuda_available_row_num += 1 + + self._cuda_to_cpu_numel += self.embedding_dim + # self.num_write_back_history[-1] += 1 + return max_cpu_row_idx + + @torch.no_grad() + def _admit(self, row_id: int): + """ + deprecated + move in row_id to CUDA + Args: + row_id (int): the id of row to be moved in + """ + # find a free slot in partial cuda weight + slot_id = self._find_free_cuda_row() + + if slot_id == -1: + # evict one row + slot_id = self._evict() + slot_offset = slot_id + # copy payload from cpu to cuda + with Timer() as timer: + cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim, + self.embedding_dim).view(1, self.embedding_dim) + cuda_tensor.data.copy_(self.cpu_weight_data(row_id)) + + # update the inverted_cached_idx + self.cached_idx_map[slot_id] = row_id + if self._evict_strategy == EvictionStrategy.LFU: + self.freq_cnter[slot_id] = 0 + self.inverted_cached_idx[row_id] = slot_offset + + self._cuda_available_row_num -= 1 + + self._cpu_to_cuda_numel += self.embedding_dim diff --git a/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c45d8e80c028637a5c964b41062846b7020c7b --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py @@ -0,0 +1,157 @@ +import torch +import torch.nn.functional as F +from typing import List, Optional, Iterator, Tuple, Union + +from .base_embedding import BaseEmbeddingBag +from .cache_mgr import CachedParamMgr, EvictionStrategy +from torch.nn.parameter import Parameter + + +class CachedEmbeddingBag(BaseEmbeddingBag): + """CachedEmbeddingBag + + Cached Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space. + It can leverage the id's frequency statistics of the target dataset, by passing a frequency list to param `ids_freq_mapping`. + You can also apply a navie LFU cache eviction strategy by setting `evict_strategy` as EvictionStrategy.LFU. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; therefore, the embedding vector at padding_idx is not updated during training, i.e. it remains as a fixed “pad”. For a newly constructed EmbeddingBag, the embedding vector at padding_idx will default to all zeros, but can be updated to another value to be used as the padding vector. Note that the embedding vector at padding_idx is excluded from the reduction. + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm + norm_type (str, optional): The p of the p-norm to compute for the max_norm option. Defaults to 2.. + scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False. Note: this option is not supported when mode="max". Defaults to False. + sparse (bool, optional): if True, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. Note: this option is not supported when mode="max".. Defaults to False. + _weight (torch.Tensor, optional): an embedding weight tensor. Concate multiple tables in a embedding bag as a single one. Defaults to None. + mode (str, optional): "sum", "mean" or "max". Specifies the way to reduce the bag. "sum" computes the weighted sum, taking per_sample_weights into consideration. "mean" computes the average of the values in the bag, "max" computes the max value over each bag. Default: "mean". Defaults to 'mean'. + include_last_offset (bool, optional): if True, offsets has one additional element, where the last element is equivalent to the size of indices. This matches the CSR format.. Defaults to False. + dtype (torch.dtype, optional): data type of the cpu weight initialization. Defaults to None meaning float32. + device (torch.device, optional): device type to the cpu weight. Defaults to None meaning cpu. + cache_ratio (float, float): cache ratio of the #cuda_weight_row / #cpu_weight_row + ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency of each embedding vector occures in dataset. Defaults to None. + warmup_ratio (float, optional): the ratio of cuda cache is warmuped with. Defaults to 0.7. + buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, the buffer is not used. Defaults to 0. + pin_weight (bool, optional): pin the cpu weight. Defaults to False. + evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + max_norm: float = None, + norm_type: float = 2., + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[torch.Tensor] = None, + mode: str = 'mean', + include_last_offset: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + cache_ratio: float = 0.01, + ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None, + warmup_ratio: float = 0.7, + buffer_size: int = 0, + pin_weight: bool = False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU): + super(CachedEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, + scale_grad_by_freq, sparse, mode, include_last_offset) + + assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0" + self.evict_strategy = evict_strategy + if _weight is None: + _weight = self._weight_alloc(dtype, device) + cuda_row_num = int(num_embeddings * cache_ratio) + # configure weight & cache + self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight) + self.cache_op = True + + def set_cache_mgr_async_copy(self, flag): + self.cache_weight_mgr._async_copy = flag + + def _weight_alloc(self, dtype, device): + weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device) + with torch.no_grad(): + weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings) + if self.padding_idx is not None: + weight[self.padding_idx].fill_(0) + return weight + + def _preprocess(self, + weight, + cuda_row_num: int, + ids_freq_mapping: Optional[List[int]] = None, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False): + """ + Called after initialized. + Reorder the weight rows according to the ids_freq_mapping. + Then, let the weights of the Module be managed by a CachedParamMgr. + + Args: + cuda_row_num (int): number of rows can be hosted in CUDA memory + ids_freq_mapping (List[int]): a list, idx is id number, value is freq + warmup_ratio (float): the amount of rows preloaded in cuda cache + """ + self.cache_weight_mgr = CachedParamMgr(weight, + cuda_row_num, + buffer_size, + pin_weight, + evict_strategy=self.evict_strategy) + self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio) + + def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None): + if self.cache_op: + with torch.no_grad(): + input = self.cache_weight_mgr.prepare_ids(input) + + embeddings = F.embedding_bag(input.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, + per_sample_weights, self.include_last_offset, self.padding_idx) + if shape_hook is not None: + embeddings = shape_hook(embeddings) + return embeddings + + @property + def weight(self): + return self.cache_weight_mgr.weight + + def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: + yield 'weight', self.cache_weight_mgr.cuda_cached_weight + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + yield self.cache_weight_mgr.cuda_cached_weight + + def set_cache_op(self, cache_op: bool = True): + self.cache_op = cache_op + + +############################# Perf Log ################################### + + @property + def num_hits_history(self): + return self.cache_weight_mgr.num_hits_history + + @property + def num_miss_history(self): + return self.cache_weight_mgr.num_miss_history + + @property + def num_write_back_history(self): + return self.cache_weight_mgr.num_write_back_history + + @property + def swap_in_bandwidth(self): + if self.cache_weight_mgr._cpu_to_cuda_numel > 0: + return self.cache_weight_mgr._cpu_to_cuda_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \ + self.cache_weight_mgr._cpu_to_cuda_elpase + else: + return 0 + + @property + def swap_out_bandwidth(self): + if self.cache_weight_mgr._cuda_to_cpu_numel > 0: + return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \ + self.cache_weight_mgr._cuda_to_cpu_elapse + return 0 diff --git a/colossalai/nn/parallel/layers/cache_embedding/copyer.py b/colossalai/nn/parallel/layers/cache_embedding/copyer.py new file mode 100644 index 0000000000000000000000000000000000000000..b586be1dc6d98ed7df2cffd1c326ca25ab837f33 --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/copyer.py @@ -0,0 +1,49 @@ +import torch +from torch import LongTensor + + +class LimitBuffIndexCopyer(object): + """LimitBuffIndexCopyer + Index Copy using limited temp buffer on CUDA. + + Args: + size (int): buffer size + """ + + def __init__(self, size: int) -> None: + self._buff_size = size + + @torch.no_grad() + def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor): + """copy + src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index] + The valid rows in the src tensor are continous, while rows in tgt tensor is scattered. + + Args: + dim (int): dimension along which to index + src_index (int): indices of src tensor to select from + tgt_index (int): indices of tgt tensor to select from + src (torch.Tensor): the tensor containing values to copy + tgt (torch.Tensor): the tensor to be copied + """ + # tgt.index_copy_(dim, index, src) + assert dim == 0, "only support index_copy on dim 0" + assert tgt.dim() == 2 + assert src.dim() == 2 + tgt_device = tgt.device + src_device = src.device + + assert src_index.numel() == tgt_index.numel() + dim_size = src_index.numel() + src_index = src_index.to(src_device) + for begin_pos in range(0, dim_size, self._buff_size): + cur_len = min(self._buff_size, dim_size - begin_pos) + src_idx_piece = src_index.narrow(0, begin_pos, cur_len) + if src_device.type == 'cpu' and tgt_device.type == 'cuda': + cpu_tmp_buffer = src.index_select(dim, src_idx_piece).pin_memory() + tmp_buffer = torch.empty_like(cpu_tmp_buffer, device=tgt_device) + tmp_buffer.copy_(cpu_tmp_buffer) + else: + tmp_buffer = src.index_select(dim, src_idx_piece).to(tgt_device) + tgt_idx_piece = tgt_index.narrow(0, begin_pos, cur_len) + tgt.index_copy_(dim, tgt_idx_piece, tmp_buffer) diff --git a/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py b/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py new file mode 100644 index 0000000000000000000000000000000000000000..36e04c833feb4203d9033a15951b580207890e8b --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py @@ -0,0 +1,27 @@ +import torch + + +class TablewiseEmbeddingBagConfig: + ''' + example: + def prepare_tablewise_config(args, cache_ratio, ...): + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [] + ... + return embedding_bag_config_list + ''' + + def __init__(self, + num_embeddings: int, + cuda_row_num: int, + assigned_rank: int = 0, + buffer_size=50_000, + ids_freq_mapping=None, + initial_weight: torch.tensor = None, + name: str = ""): + self.num_embeddings = num_embeddings + self.cuda_row_num = cuda_row_num + self.assigned_rank = assigned_rank + self.buffer_size = buffer_size + self.ids_freq_mapping = ids_freq_mapping + self.initial_weight = initial_weight + self.name = name diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..d7f77e195f4b480da53e6f1772145d122afb1e69 --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py @@ -0,0 +1,141 @@ +import torch +import torch.nn.functional as F +from typing import List, Optional, Iterator, Tuple + +from .cached_embedding import CachedEmbeddingBag +from colossalai.nn._ops._utils import dual_all_to_all + +from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor +from .cache_mgr import CachedParamMgr, EvictionStrategy + + +def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: + if world_size == 1: + return 0, embedding_dim, True + + assert embedding_dim >= world_size, \ + f"Embedding dimension {embedding_dim} must be larger than the world size " \ + f"{world_size} of the process group" + chunk_size = embedding_dim // world_size + threshold = embedding_dim % world_size + # if embedding dim is divisible by world size + if threshold == 0: + return rank * chunk_size, (rank + 1) * chunk_size, True + + # align with the split strategy of torch.tensor_split + size_list = [chunk_size + 1 if i < threshold else chunk_size for i in range(world_size)] + offset = sum(size_list[:rank]) + return offset, offset + size_list[rank], False + + +class ParallelCachedEmbeddingBag(CachedEmbeddingBag): + + def __init__(self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2., + scale_grad_by_freq=False, + sparse=False, + _weight=None, + mode='mean', + include_last_offset=False, + dtype=None, + device=None, + cache_ratio=0.01, + ids_freq_mapping=None, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.DATASET): + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + self.partition_start_index, self.partition_end_index, divisible = get_partition( + embedding_dim, self.rank, self.world_size) + self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index + + super(ParallelCachedEmbeddingBag, + self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, + sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, + warmup_ratio, buffer_size, pin_weight, evict_strategy) + self.cache_op = True + + def _weight_alloc(self, dtype, device): + weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype) + with torch.no_grad(): + weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings) + if self.padding_idx is not None: + weight[self.padding_idx].fill_(0) + colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size), + dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]), + compute_attr=ComputePattern.TP1D) + return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec) + + def forward( + self, + indices, + offsets=None, + per_sample_weights=None, + shape_hook=None, + scatter_dim=0, + gather_dim=-1, + ): + if self.cache_op: + with torch.no_grad(): + indices = self.cache_weight_mgr.prepare_ids(indices) + output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, + per_sample_weights, self.include_last_offset, self.padding_idx) + if shape_hook is not None: + output_shard = shape_hook(output_shard) + output_full = dual_all_to_all(output_shard, + self.weight.get_process_group(), + scatter_dim=scatter_dim, + gather_dim=gather_dim) + return output_full + + def set_cache_op(self, cache_op: bool = True): + self.cache_op = cache_op + + @classmethod + def from_pretrained( + cls, + embedding: torch.Tensor, + freeze: bool = True, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2., + scale_grad_by_freq: bool = False, + sparse: bool = False, + mode: str = 'mean', + include_last_offset: bool = False, + cuda_row_num: int = 100_000, + ids_freq_mapping: Optional[List[int]] = None, + warmup_ratio: float = 0.7, + buffer_size: int = 0, + ) -> 'ParallelCachedEmbeddingBag': + rows, cols = embedding.shape + embedding_bag = cls(rows, + cols, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + embedding, + mode, + include_last_offset, + cuda_row_num=cuda_row_num, + ids_freq_mapping=ids_freq_mapping, + warmup_ratio=warmup_ratio, + buffer_size=buffer_size) + embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_ = not freeze + return embedding_bag + + def print_comm_stats_(self): + self.cache_weight_mgr.print_comm_stats() + + def element_size(self): + return self.weight.element_size() diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py new file mode 100644 index 0000000000000000000000000000000000000000..949f85ad4baf894d4e53f06594bcbd9080f249ae --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py @@ -0,0 +1,198 @@ +import torch +import torch.distributed as dist +import torch.nn.functional as F + +from .cached_embedding import CachedEmbeddingBag +from .cache_mgr import EvictionStrategy +from .embedding_config import TablewiseEmbeddingBagConfig +from colossalai.tensor import ProcessGroup +from colossalai.nn._ops._utils import dual_all_to_all_tablewise + +from typing import List +import time + + +class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag): + """ + all tables assigned to this class instance are managed by a single CachedEmbeddingBag. + Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight. + """ + + def __init__(self, + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], + embedding_dim: int, + padding_idx=None, + max_norm=None, + norm_type=2., + scale_grad_by_freq=False, + sparse=False, + _weight=None, + mode='mean', + include_last_offset=False, + dtype=None, + device=None, + cache_ratio=0.01, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list] + self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list] + self.global_tables_num = len(embedding_bag_config_list) + self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda() + self.assigned_table_list: List[int] = [] + self.pg = ProcessGroup(tp_degree=self.world_size) + self.num_embeddings = 0 + for i, rank in enumerate(self.rank_of_tables): + if rank == self.rank: + self.assigned_table_list.append(i) + self.num_embeddings += self.global_table_num_embeddings_list[i] + self.include_last_offset = include_last_offset + + ids_freq_mapping = [] + for config in embedding_bag_config_list: + if config.assigned_rank == self.rank: + if config.ids_freq_mapping != None: + ids_freq_mapping.extend(config.ids_freq_mapping) + else: + ids_freq_mapping = None + break + self.cache_ratio = cache_ratio + # table-associate cache + cuda_row_num = int(cache_ratio * self.num_embeddings) + super(ParallelCachedEmbeddingBagTablewise, + self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, + sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, + warmup_ratio, buffer_size, pin_weight, evict_strategy) + + # for assigned tables reconnection: + self.idx_offset_list = [] + offset_cumsum = 0 + for table_i, table_num_embeddings in enumerate(self.global_table_num_embeddings_list): + if self.rank_of_tables[table_i] == self.rank: + self.idx_offset_list.append(offset_cumsum) + else: + offset_cumsum += table_num_embeddings + + # prepare list shape for all_to_all output + self.embedding_dim_per_rank = [0 for i in range(self.world_size)] + for rank in self.rank_of_tables: + self.embedding_dim_per_rank[rank] += embedding_dim + + self.cache_op = True + + def forward( + self, + indices: torch.Tensor, + offsets: torch.Tensor = None, + per_sample_weights=None, + shape_hook=None, + already_split_along_rank=True, + ): + if not already_split_along_rank: + # not recommanded. it takes time. + batch_size = (offsets.shape[0]) // self.global_tables_num + local_indices, local_offsets, local_per_sample_weights = self.split_along_rank( + batch_size, indices, offsets, per_sample_weights) + else: + # recommanded. + batch_size = (offsets.shape[0]) // len(self.assigned_table_list) + local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights + if self.cache_op: + with torch.no_grad(): + indices = self.cache_weight_mgr.prepare_ids(local_indices) + local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets, + self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, + local_per_sample_weights, self.include_last_offset, self.padding_idx) + local_output = torch.cat(local_output.split(batch_size), 1) + remains = batch_size % self.world_size + scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)] + output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank) + if shape_hook is not None: + output_full = shape_hook(output_full) + return output_full + + def split_along_rank(self, + batch_size, + indices: torch.Tensor, + offsets: torch.Tensor = None, + per_sample_weights=None): + ''' + if input indices and offsets haven't been splitted along assigned rank, this function will do it. + it takes time. please consider splitting data during batch loading. + ''' + local_indices_list: List(torch.Tensor) = [] + local_offsets_list: List(torch.Tensor) = [] + if per_sample_weights != None: + local_per_sample_weights_list: List(torch.Tensor) = [] + + offset_pre_end = 0 # local_offsets trick + for i, handle_table in enumerate(self.assigned_table_list): + indices_start_position = offsets[batch_size * handle_table] + if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]): + # till-the-end special case + indices_end_position = indices.shape[0] + else: + indices_end_position = offsets[batch_size * (handle_table + 1)] + # alternative approach: reduce malloc + ''' + # 1. local_indices_list: + local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position) + torch.sub(local_indices, self.idx_offset_list[i], out=local_indices) + local_indices_list.append(local_indices) + # 2. local_offsets_list: + if i + 1 == len(self.assigned_table_list): + # till-the-end special case + if not self.include_last_offset: + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size) + else: + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1) + torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets) + local_offsets_list.append(local_offsets) + else: + temp_holder = offsets[batch_size * handle_table].item() + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size) + torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets) + offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder + local_offsets_list.append(local_offsets) + ''' + # 1. local_indices_list: + local_indices_list.append( + indices.narrow(0, indices_start_position, + indices_end_position - indices_start_position).sub(self.idx_offset_list[i])) + # 2. local_offsets_list: + if i + 1 == len(self.assigned_table_list): + # till-the-end special case + if not self.include_last_offset: + local_offsets = offsets.narrow(0, batch_size * handle_table, + batch_size).add(offset_pre_end - offsets[batch_size * + (handle_table)]) + else: + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + + 1).add(offset_pre_end - offsets[batch_size * (handle_table)]) + local_offsets_list.append(local_offsets) + else: + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + + 1).add(offset_pre_end - offsets[batch_size * (handle_table)]) + offset_pre_end = local_offsets[-1] + local_offsets_list.append(local_offsets[:-1]) + # 3. local_per_sample_weights_list: + if per_sample_weights != None: + local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position]) + local_indices = torch.cat(local_indices_list, 0) + local_offsets = torch.cat(local_offsets_list, 0) + local_per_sample_weights = None + if per_sample_weights != None: + local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0) + return local_indices, local_offsets, local_per_sample_weights + + def set_cache_op(self, cache_op: bool = True): + self.cache_op = cache_op + + def print_comm_stats_(self): + self.cache_weight_mgr.print_comm_stats() + + def element_size(self): + return self.weight.element_size() diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..cb4647028d477d6c89ddfcb0d68d906ba43aa4c8 --- /dev/null +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py @@ -0,0 +1,138 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.profiler import record_function + +from .cached_embedding import CachedEmbeddingBag + +from colossalai.tensor import ProcessGroup +from colossalai.nn._ops._utils import dual_all_to_all_tablewise +from .embedding_config import TablewiseEmbeddingBagConfig +from .cache_mgr import EvictionStrategy + +from typing import List +import abc + + +class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module): + """ + every table assigned to this class instance is managed by a CachedEmbeddingBag. + """ + + def __init__(self, + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], + embedding_dim: int, + padding_idx=None, + max_norm=None, + norm_type=2., + scale_grad_by_freq=False, + sparse=False, + mode='mean', + include_last_offset=False, + dtype=None, + device=None, + warmup_ratio=0.7, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU): + super(ParallelCachedEmbeddingBagTablewiseSpiltCache, self).__init__() + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list] + self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list] + self.global_tables_num = len(embedding_bag_config_list) + self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda() + + self.assigned_table_list: List[int] = [] + for i, rank in enumerate(self.rank_of_tables): + if rank == self.rank: + self.assigned_table_list.append(i) + self.include_last_offset = include_last_offset + self.pg = ProcessGroup(tp_degree=self.world_size) + + # prepare CachedEmbeddingBag list + + self.cached_embedding_bag_list: nn.ModuleList = nn.ModuleList() + for config in embedding_bag_config_list: + if config.assigned_rank != self.rank: + continue + self.cached_embedding_bag_list.append( + CachedEmbeddingBag(num_embeddings=config.num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=config.initial_weight, + mode=mode, + include_last_offset=include_last_offset, + dtype=dtype, + device=device, + cuda_row_num=config.cuda_row_num, + ids_freq_mapping=config.ids_freq_mapping, + warmup_ratio=warmup_ratio, + buffer_size=config.buffer_size, + pin_weight=pin_weight, + evict_strategy=evict_strategy)) + + # prepare list shape for all_to_all output + self.embedding_dim_per_rank = [0 for i in range(self.world_size)] + for rank in self.rank_of_tables: + self.embedding_dim_per_rank[rank] += embedding_dim + + def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None): + # determine indices to handle + batch_size = (offsets.shape[0]) // self.global_tables_num + local_output_list = [] + for i, handle_table in enumerate(self.assigned_table_list): + with record_function("(tablewise) prepare indices and offsets"): + with record_function("part 1"): + indices_start_position = offsets[batch_size * handle_table] + if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]): + # till the end special case + indices_end_position = indices.shape[0] + else: + indices_end_position = offsets[batch_size * (handle_table + 1)] + with record_function("part 2"): + # local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table] + local_indices = indices.narrow(0, indices_start_position, indices_end_position - + indices_start_position).sub(self.global_tables_offsets[handle_table]) + if self.include_last_offset: + # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)] + local_offsets = offsets.narrow(0, batch_size * handle_table, + batch_size + 1).sub(offsets[batch_size * (handle_table)]) + else: + # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)] + local_offsets = offsets.narrow(0, batch_size * handle_table, + batch_size).sub(offsets[batch_size * (handle_table)]) + local_per_sample_weights = None + if per_sample_weights != None: + local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position] + with record_function("(tablewise) tablewise forward"): + local_output_list.append(self.cached_embedding_bag_list[i](local_indices, local_offsets, + local_per_sample_weights)) + + # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim)) + local_output = torch.cat(local_output_list, 1) + # then concatenate those local_output on the second demension. + # use all_to_all + remains = batch_size % self.world_size + scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)] + output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank) + if shape_hook is not None: + output_full = shape_hook(output_full) + return output_full + + def element_size(self): + if len(self.assigned_table_list) == 0: + return 0 + return self.cached_embedding_bag_list[0].cache_weight_mgr.weight.element_size() + + def print_comm_stats_(self): + cuda_to_cpu_elem_num = 0 + cpu_to_cuda_elem_num = 0 + for cached_embedding_bag in self.cached_embedding_bag_list: + cuda_to_cpu_elem_num += cached_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel + cpu_to_cuda_elem_num += cached_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel + print(f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem") + print(f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem") diff --git a/colossalai/nn/parallel/layers/colo_module.py b/colossalai/nn/parallel/layers/colo_module.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0f5d5f520a17c979a21048903b025dc642296f --- /dev/null +++ b/colossalai/nn/parallel/layers/colo_module.py @@ -0,0 +1,46 @@ +from colossalai.tensor.distspec import _DistSpec +from colossalai.tensor import ComputePattern +from typing import List, Dict + + +class ColoModule(object): + + def __init__(self): + self._shard_params: List[str] = [] + self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {} + + def _register_shard_params(self, params: List[str]): + self._shard_params = params + + def _register_allowed_patterns(self, + compute_pattern: ComputePattern, + dist_specs: Dict[str, _DistSpec], + mode='default'): + assert list( + dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.' + if not compute_pattern in self._allowed_patterns: + self._allowed_patterns[compute_pattern] = {} + self._allowed_patterns[compute_pattern][mode] = dist_specs + + def _set_default(self, compute_pattern: ComputePattern, target_mode): + self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_mode] + + def has_compute_pattern(self, compute_pattern: ComputePattern): + return compute_pattern in self._allowed_patterns + + def get_dist_specs(self, compute_pattern: ComputePattern): + assert self.has_compute_pattern(compute_pattern) + return self._allowed_patterns[compute_pattern] + + def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode='default'): + return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern] + + def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode='default'): + assert self.has_compute_pattern_with_mode(compute_pattern, mode) + return self._allowed_patterns[compute_pattern][mode] + + def get_param_names(self): + return self._shard_params + + def register(self, compute_pattern, pg): + raise NotImplementedError diff --git a/colossalai/nn/parallel/layers/embedding.py b/colossalai/nn/parallel/layers/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..ccacc1ead297b349c7252aec37a41f106dff7993 --- /dev/null +++ b/colossalai/nn/parallel/layers/embedding.py @@ -0,0 +1,36 @@ +from .colo_module import ColoModule +from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec + + +class ColoEmbedding(ColoModule): + + def __init__(self): + super(ColoEmbedding, self).__init__() + self._register_shard_params(['weight']) + + def register(self, compute_pattern, pg: ProcessGroup): + if not compute_pattern in self._allowed_patterns: + if ComputePattern.TP1D == compute_pattern: + self._set_TP1D(pg) + + def _set_TP1D(self, pg: ProcessGroup): + # TP1D Row Linear + _compute_pattern = ComputePattern.TP1D + self._register_allowed_patterns( + compute_pattern=_compute_pattern, + dist_specs={ + 'weight': ShardSpec([0], [pg.tp_world_size()]), + }, + mode='row', + ) + + # TP1D Col Linear + self._register_allowed_patterns( + compute_pattern=_compute_pattern, + dist_specs={ + 'weight': ShardSpec([-1], [pg.tp_world_size()]), + }, + mode='col', + ) + + self._set_default(compute_pattern=_compute_pattern, target_mode='row') diff --git a/colossalai/nn/parallel/layers/linear.py b/colossalai/nn/parallel/layers/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..84a8c042587dfdb279f4bd1f83e07d317bc57d08 --- /dev/null +++ b/colossalai/nn/parallel/layers/linear.py @@ -0,0 +1,38 @@ +from .colo_module import ColoModule +from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec + + +class ColoLinear(ColoModule): + + def __init__(self): + super(ColoLinear, self).__init__() + self._register_shard_params(['weight', 'bias']) + + def register(self, compute_pattern, pg: ProcessGroup): + if not compute_pattern in self._allowed_patterns: + if ComputePattern.TP1D == compute_pattern: + self._set_TP1D(pg) + + def _set_TP1D(self, pg): + # TP1D Row Linear + _compute_pattern = ComputePattern.TP1D + self._register_allowed_patterns( + compute_pattern=_compute_pattern, + dist_specs={ + 'weight': ShardSpec([-1], [pg.tp_world_size()]), + 'bias': None + }, + mode='row', + ) + + # TP1D Col Linear + self._register_allowed_patterns( + compute_pattern=_compute_pattern, + dist_specs={ + 'weight': ShardSpec([0], [pg.tp_world_size()]), + 'bias': ShardSpec([0], [pg.tp_world_size()]) + }, + mode='col', + ) + + self._set_default(compute_pattern=_compute_pattern, target_mode='row') diff --git a/colossalai/nn/parallel/layers/module_utils.py b/colossalai/nn/parallel/layers/module_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..38d128cc705e6bfa3db68cb6c59b66450ce1223e --- /dev/null +++ b/colossalai/nn/parallel/layers/module_utils.py @@ -0,0 +1,113 @@ +from typing import Dict +from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup +from colossalai.tensor import distspec +from . import ColoModule +import torch + +_COLOSSAL_MODULES: Dict[type, ColoModule] = {} + + +def register_colo_module(module_type: type, colo_module: ColoModule): + global _COLOSSAL_MODULES + _COLOSSAL_MODULES[module_type] = colo_module + + +def is_colo_module(module: torch.nn.Module): + global _COLOSSAL_MODULES + for module_type in _COLOSSAL_MODULES.keys(): + if isinstance(module, module_type): + return True + return False + + +def get_colo_module(module: torch.nn.Module): + global _COLOSSAL_MODULES + if is_colo_module(module): + for module_type, colo_module in _COLOSSAL_MODULES.items(): + if isinstance(module, module_type): + return colo_module + else: + return None + + +def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True): + if is_colo_module(module): + colo_module = get_colo_module(module) + param_names = colo_module.get_param_names() + compute_pattern = None + for param_name in param_names: + param = module.get_parameter(param_name) + if not isinstance(param, ColoParameter): + raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.') + if param.has_compute_spec(): + cur_compute_pattern = param.compute_spec.compute_pattern + if compute_pattern is None: + compute_pattern = cur_compute_pattern + else: + if cur_compute_pattern != compute_pattern: + raise Exception( + f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.') + else: + continue + + if compute_pattern is not None: + colo_module.register(compute_pattern, pg) + if not colo_module.has_compute_pattern(compute_pattern): + raise Exception( + f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.') + + match_specs = False + allowed_specs = colo_module.get_dist_specs(compute_pattern) + for _, param_specs in allowed_specs.items(): + cur_match = True + for param_name, dist_spec in param_specs.items(): + param = module.get_parameter(param_name) + if param.has_compute_spec(): + if dist_spec != param.dist_spec: + cur_match = False + break + else: + if dist_spec is not None: + cur_match = False + break + if cur_match == True: + match_specs = True + break + if match_specs == False: + raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.') + if recursive == True: + for submodule in module.children(): + check_colo_module(submodule, pg=pg, recursive=True) + + +def init_colo_module(module: torch.nn.Module, + compute_spec: ComputeSpec, + pg: ProcessGroup, + recursive=True, + mode='default'): + compute_pattern = compute_spec.compute_pattern + if is_colo_module(module): + # for each param + # set its process_group, dist_spec and compute_spec + colo_module = get_colo_module(module) + colo_module.register(compute_pattern, pg) + if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode): + raise NotImplementedError + # a set for modules which update at least one param in the init process. + # these modules need to be checked whether all params still match one of the valid compute pattern. + modules_update_param = {module} + for param_name, dist_spec in colo_module.get_dist_specs_with_mode(compute_pattern, mode=mode).items(): + if dist_spec is None: + continue + param = module.get_parameter(param_name) + if isinstance(param, ColoParameter): + param.set_process_group(pg) + param.set_dist_spec(dist_spec) + param.compute_spec = compute_spec + for mod in param.shared_param_modules: + modules_update_param.add(mod) + for mod in modules_update_param: + check_colo_module(mod, pg, recursive=False) + if recursive == True: + for submodule in module.children(): + init_colo_module(submodule, compute_spec, pg=pg, recursive=True, mode=mode) diff --git a/colossalai/nn/parallel/reducer.py b/colossalai/nn/parallel/reducer.py new file mode 100644 index 0000000000000000000000000000000000000000..5687055819fe1fd0177e507f1a27d3bab8b5b1b5 --- /dev/null +++ b/colossalai/nn/parallel/reducer.py @@ -0,0 +1,116 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import functools +from typing import Callable, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup + + +class Bucket: + + def __init__(self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup): + self.buffer = torch.zeros(size, dtype=dtype, device=device) + self.group = group + self.offset = 0 + self.callbacks: List[Callable] = [] + + def flush(self) -> None: + """Flush content of the bucket.""" + if self.offset == 0: + assert len(self.callbacks) == 0 + return + # reduce-scatter bucket + dist.all_reduce(self.buffer[:self.offset], group=self.group) + + # execute post-reduction callbacks + for callback_fn in self.callbacks: + callback_fn() + # reuse input bucket but allocate a fresh output shard + self.offset = 0 + self.callbacks.clear() + self.buffer = torch.zeros_like(self.buffer) + + def alloc(self) -> None: + + if self.buffer.storage().size() == 0: + self.buffer.storage().resize_(self.buffer.numel()) + + def free(self) -> None: + + assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown" + self.buffer.storage().resize_(0) + + def append(self, tensor: Tensor, callback_fn: Callable): + tensor_size = tensor.numel() + offset = self.offset + self.buffer[offset:offset + tensor_size].copy_(tensor.flatten()) + self.offset += tensor_size + + # callback will be given the reduced result + if callback_fn is not None: + result_view = self.buffer[offset:offset + tensor_size].view(tensor.shape) + self.callbacks.append(functools.partial(callback_fn, result_view)) + + @property + def avail_size(self) -> int: + return self.buffer.size(0) - self.offset + + +class Reducer: + + def __init__(self, bucket_size_mb: int = 25): + self.bucket_size_mb = bucket_size_mb + self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {} + + @torch.no_grad() + def all_reduce_async( + self, + tensor: Tensor, + group: ProcessGroup, + callback_fn: Optional[Callable] = None, + ) -> None: + bucket_size = self._get_bucket_size(tensor.element_size()) + + if tensor.numel() >= bucket_size: + dist.all_reduce(tensor, group=group) + if callback_fn is not None: + callback_fn(tensor) + return + + bucket = self._get_bucket(tensor, group) + if tensor.numel() > bucket.avail_size: + # not enough space remaining in bucket, flush it now + bucket.flush() + bucket.append(tensor, callback_fn) + + @torch.no_grad() + def flush(self) -> None: + for bucket in self.buckets.values(): + bucket.flush() + + @torch.no_grad() + def free(self) -> None: + for bucket in self.buckets.values(): + bucket.free() + + @functools.lru_cache() + def _get_bucket_size(self, element_size: int) -> int: + if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. + return 0 + MB = 1024 * 1024 + bucket_size = self.bucket_size_mb * MB / element_size + return int(bucket_size) + + def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket: + key = (tensor.dtype, tensor.device, group) + if key not in self.buckets: + bucket_size = self._get_bucket_size(tensor.element_size()) + self.buckets[key] = Bucket(bucket_size, tensor.dtype, tensor.device, group) + self.buckets[key].alloc() + return self.buckets[key] diff --git a/colossalai/pipeline/__init__.py b/colossalai/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0fcde970764688c210d37de51964b5b081834666 --- /dev/null +++ b/colossalai/pipeline/__init__.py @@ -0,0 +1,4 @@ +from .pipelinable import PipelinableContext, PipelinableModel +from .layer_spec import LayerSpec + +__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec'] \ No newline at end of file diff --git a/colossalai/pipeline/layer_spec.py b/colossalai/pipeline/layer_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9169efff78bad7d30f42e9896dc3f61ecaf7fd --- /dev/null +++ b/colossalai/pipeline/layer_spec.py @@ -0,0 +1,55 @@ +import torch +from colossalai.utils.model.utils import call_to_str + +class LayerSpec: + """ + + """ + + def __init__(self, typename, *module_args, **module_kwargs): + self.typename = typename + self.module_args = module_args + self.module_kwargs = module_kwargs + self.children = None + self._param_count = 0 + + if not issubclass(typename, torch.nn.Module): + raise RuntimeError('LayerSpec only supports torch.nn.Module types.') + + def __repr__(self): + return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs) + + @property + def param_count(self): + return self._param_count + + def build(self): + """Build the stored specification.""" + + recovered_args = [] + for obj in self.module_args: + if isinstance(obj, LayerSpec): + obj = obj.build() + recovered_args.append(obj) + recovered_args = tuple(recovered_args) + + recovered_kwargs = {} + for k, v in self.module_kwargs.items(): + if isinstance(v, LayerSpec): + v = v.build() + recovered_kwargs[k] = v + + return self.typename(*recovered_args, **recovered_kwargs) + + def set_children(self, children): + self.children = children + + def count_params(self): + self._param_count = 0 + layer = self.build() + for param in layer.parameters(): + self._param_count += param.numel() + return self._param_count + + def reset_param_count(self): + self._param_count = 0 \ No newline at end of file diff --git a/colossalai/pipeline/middleware/__init__.py b/colossalai/pipeline/middleware/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79e19f9eaf771635852d8ffd747c06ea1209e110 --- /dev/null +++ b/colossalai/pipeline/middleware/__init__.py @@ -0,0 +1,3 @@ +from .topo import Topo, Partition, PartitionOutputVal, PartitionInputVal + +__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal'] \ No newline at end of file diff --git a/colossalai/pipeline/middleware/adaptor/__init__.py b/colossalai/pipeline/middleware/adaptor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..949700a2c49de505b37c408fd44283ef67f9569b --- /dev/null +++ b/colossalai/pipeline/middleware/adaptor/__init__.py @@ -0,0 +1,3 @@ +from .fx import get_topology as get_fx_topology + +__all__ = ['get_fx_topology'] \ No newline at end of file diff --git a/colossalai/pipeline/middleware/adaptor/fx.py b/colossalai/pipeline/middleware/adaptor/fx.py new file mode 100644 index 0000000000000000000000000000000000000000..8437c519476218dec90c968a47a73440ed71f519 --- /dev/null +++ b/colossalai/pipeline/middleware/adaptor/fx.py @@ -0,0 +1,145 @@ +from torch.fx.graph_module import GraphModule +from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo +import torch + +def partition_name_to_id(partition_name, is_input=False, is_output=False): + if is_input: + partition_id = 0 + elif is_output: + partition_id = 1 + else: + prefix = 'submod_' + partition_id = int(partition_name.split(prefix)[-1]) + 2 + return partition_id + +# There are two kinds of def in fx.graph +# 1. non direct_use & non direct_def, which means the output is used by next partition with a temporary mid value. +# e.g. submod1 = call_module(...) +# temporary_val = submod1[0] +# submod2 = call_module(temporary_val, ...) +# 2. direct_use & direct_def, which means the output is used by next partition directly. +# e.g. submod1 = call_module(...) +# submod2 = call_module(submod1, ...) +def find_input_in_partition(node, partitions, input_partitions=None): + p_input_val = None + direct_def = not node.name.startswith('getitem') + # search in input + if direct_def and input_partitions is not None: + partition_id = partition_name_to_id('', is_input=True) + for i, input_node in enumerate(input_partitions): + if input_node == node: + p_input_val = PartitionInputVal(partition_id=partition_id, offset=i) + return p_input_val + # search submod in mid part + if direct_def: + for partition in partitions: + if partition == node: + partition_id = partition_name_to_id(partition.name) + p_input_val = PartitionInputVal(partition_id=partition_id, offset=0) + return p_input_val + # search temporary value in graph + else: + for partition in partitions: + for offset, mid_val in enumerate(partition.users): + if mid_val == node: + partition_id = partition_name_to_id(partition.name) + p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset) + return p_input_val + + return p_input_val + +def find_output_in_partition(node, partitions, output_partitions=None): + p_output_val = PartitionOutputVal() + for user in node.users: + direct_use = not user.name.startswith('getitem') + # user is mid partition + for partition in partitions: + # direct call + if direct_use: + if user == partition: + partition_id = partition_name_to_id(partition.name) + for i, arg in enumerate(partition.args): + if arg == node: + p_output_val.add(partition_id=partition_id, offset=i) + break + # getitem call + else: + if user in partition.args: + partition_id = partition_name_to_id(partition.name) + for i, arg in enumerate(partition.args): + if arg == user: + p_output_val.add(partition_id=partition_id, offset=i) + break + + # user is output + if output_partitions is not None: + output_node = output_partitions[0] + if user.op == output_node.op: + output_keys = {} + partition_id = partition_name_to_id('', is_output=True) + torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n)) + for i, arg in enumerate(output_keys): + if arg == node: + p_output_val.add(partition_id=partition_id, offset=i) + break + return p_output_val + +def get_topology(gm: GraphModule): + topo = Topo() + topo_output_partition = Partition() + + input_partitions = [] + partitions = [] + output_partitions = [] + for node in gm.graph.nodes: + if node.op == 'placeholder': + input_partitions.append(node) + elif node.name.startswith('submod_'): + partitions.append(node) + elif node.op == 'output': + output_partitions.append(node) + else: + continue + + # set output for input_partition + topo_input_partition = Partition() + for partition in input_partitions: + cur_node = partition + p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) + topo_input_partition.add_output_val(p_output_val) + topo.set_partitions(partition_id=0, partition=topo_input_partition) + topo.set_input_partition_id(partition_id=0) + + for i, partition in enumerate(partitions): + topo_mid_partition = Partition() + # set input for submodule + for arg in partition.args: + cur_node = arg + p_input_val = find_input_in_partition(cur_node, partitions, input_partitions) + topo_mid_partition.add_input_val(p_input_val) + # set output for submodule + direct_use = True + for user in partition.users: + if user.name.startswith('getitem'): + direct_use = False + break + if direct_use: + cur_node = partition + p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) + topo_mid_partition.add_output_val(p_output_val) + else: + for user in partition.users: + cur_node = user + p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) + topo_mid_partition.add_output_val(p_output_val) + topo.set_partitions(partition_id=i+2, partition=topo_mid_partition) + + # set input for output_partition + for partition in output_partitions: + topo_output_partition = Partition() + torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val( + find_input_in_partition(n, partitions, input_partitions))) + topo.set_partitions(partition_id=1, partition=topo_output_partition) + topo.set_output_partition_id(partition_id=1) + + return topo \ No newline at end of file diff --git a/colossalai/pipeline/middleware/topo.py b/colossalai/pipeline/middleware/topo.py new file mode 100644 index 0000000000000000000000000000000000000000..e798e2ed9cab0cd3036b2dc39169c2f81470bae8 --- /dev/null +++ b/colossalai/pipeline/middleware/topo.py @@ -0,0 +1,206 @@ +from typing import Dict, List +from dataclasses import dataclass + +# This file includes data structure used by Pipeline Middleware. + +@dataclass +class ValPosition: + partition_id: int + offset: int + + def __str__(self) -> str: + res = f'[partition_id:{self.partition_id},offset:{self.offset}]' + return res + + def __repr__(self) -> str: + return self.__str__() + +class PartitionInputVal(object): + def __init__(self, partition_id, offset) -> None: + # every input from which partition_id and which offset + val_pos = ValPosition(partition_id, offset) + self._from_partition_and_offset: ValPosition = val_pos + + def get(self): + return self._from_partition_and_offset + + def __str__(self) -> str: + res = '' + res += f'<-({self._from_partition_and_offset})' + return res + + def __repr__(self) -> str: + return self.__str__() + +class PartitionOutputVal(object): + def __init__(self) -> None: + # every output to which partition_id and which offset + self._to_partition_and_offset: List[ValPosition] = [] + + def add(self, partition_id, offset): + val_pos = ValPosition(partition_id, offset) + self._to_partition_and_offset.append(val_pos) + + def get(self): + return self._to_partition_and_offset + + def __str__(self) -> str: + res = '' + res += '->(' + for val_pos in self._to_partition_and_offset: + res += f'{val_pos},' + res += ')' + return res + + def __repr__(self) -> str: + return self.__str__() + +class Partition(object): + def __init__(self) -> None: + self._input_vals: List[PartitionInputVal] = [] + self._output_vals: List[PartitionOutputVal] = [] + + def add_input_val(self, input_val: PartitionInputVal): + self._input_vals.append(input_val) + + def add_output_val(self, output_val: PartitionOutputVal): + self._output_vals.append(output_val) + + def get_input_vals(self): + return self._input_vals + + def get_output_vals(self): + return self._output_vals + + # get the output offsets sent to dst_partition_id + def get_output_offsets(self, dst_partition_id): + res = [] + for offset, output_val in enumerate(self._output_vals): + outputs = output_val.get() + for val_pos in outputs: + if val_pos.partition_id == dst_partition_id: + res.append(offset) + + return res + + # get all input dst partition_ids + def get_input_partition_ids(self): + res = [] + for input_val in self._input_vals: + val_pos = input_val.get() + if val_pos.partition_id not in res: + res.append(val_pos.partition_id) + return res + + # get all output dst partition_ids + def get_output_partition_ids(self): + res = [] + for output_val in self._output_vals: + outputs = output_val.get() + for val_pos in outputs: + if val_pos.partition_id not in res: + res.append(val_pos.partition_id) + return res + + def __str__(self) -> str: + res = '' + res += f' input:\n' + res += f' length:{len(self._input_vals)}\n' + for i, input_val in enumerate(self._input_vals): + res += f' offset={i}:{input_val}\n' + + res += f' output:\n' + res += f' length:{len(self._output_vals)}\n' + for i, output_val in enumerate(self._output_vals): + res += f' offset={i}:{output_val}\n' + + return res + + def __repr__(self) -> str: + return self.__str__() + +# This class is a middleware between partition splitter +# and Pipeline Scheduler. It records the graph info about +# partition input/output and provides it to scheduler. +# There are three kinds of partition in Pipeline Middleware Design +# which represents the whole process of a model execution: input-fwd-output +# 1. input_partition: records the input of a model. +# 2. mid_partition: record the splitted forwards execution of a model. +# 3. output_partition: records the output of a model. +# attributes: +# _partitions: include all partitions +# _input_partition_id: the key represents input_partition +# _output_partition_id: the key represents output_partition +class Topo(object): + def __init__(self, input_partition_id=None, output_partition_id=None) -> None: + self._partitions: Dict[int, Partition] = {} + self._input_partition_id = input_partition_id + self._output_partition_id = output_partition_id + + def set_input_partition_id(self, partition_id: int): + self._input_partition_id = partition_id + + def set_output_partition_id(self, partition_id: int): + self._output_partition_id = partition_id + + def get_input_partition_id(self): + return self._input_partition_id + + def get_output_partition_id(self): + return self._output_partition_id + + def set_partitions(self, partition_id: int, partition: Partition): + self._partitions[partition_id] = partition + + def get_mid_partitions(self): + res = {} #{partition_id: Partition} + for partition_id, partition in self._partitions.items(): + if self._input_partition_id == partition_id or self._output_partition_id == partition_id: + continue + res[partition_id] = partition + return res + + def get_mid_partition_ids(self): + return list(self.get_mid_partitions().keys()) + + def get_input_partition(self): + if self._input_partition_id is not None: + return self._partitions[self._input_partition_id] + return None + + def get_output_partition(self): + if self._output_partition_id is not None: + return self._partitions[self._output_partition_id] + return None + + def get_partition_by_id(self, partition_id): + return self._partitions[partition_id] + + def __str__(self) -> str: + res = '' + if len(self._partitions) == 0: + return 'Empty Topo Graph.' + + input_part = self.get_input_partition() + if input_part is not None: + res += '{\n' + res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}' + res += '}\n' + + mid_parts = self.get_mid_partitions() + for i, (partition_id, part) in enumerate(mid_parts.items()): + res += '{\n' + res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}' + res += '}\n' + + output_part = self.get_output_partition() + if output_part is not None: + res += '{\n' + res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}' + res += '}\n' + + return res + + def __repr__(self) -> str: + return self.__str__() + \ No newline at end of file diff --git a/colossalai/pipeline/pipelinable.py b/colossalai/pipeline/pipelinable.py new file mode 100644 index 0000000000000000000000000000000000000000..9731530a6e15755c9d152697fec0ae0cfd102328 --- /dev/null +++ b/colossalai/pipeline/pipelinable.py @@ -0,0 +1,254 @@ +import torch +import inspect +from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses + +from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, \ + build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, \ + call_module, customized_partition +from colossalai.nn.layer.utils import CheckpointModule +from colossalai.tensor import ColoParameter +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from .layer_spec import LayerSpec + + +class PipelinableContext(InsertPostInitMethodToModuleSubClasses): + """ + A context manager to split the model into pipeline stages. + """ + + def __init__(self, policy: str = "balanced"): + super().__init__() + self._layer_spec_dict = {} + self._root_children = None + self._model = None + self._layer_spec_list = [] + self._func_dict = {} + self._policy = policy + + @property + def policy(self): + return self._policy + + @policy.setter + def policy(self, policy: str): + self._policy = policy + + @property + def layers_count(self): + return len(self._layer_spec_list) + + @property + def funcs_count(self): + return len(self._func_dict) + + def _pre_context_exec(self): + """ + The Callback function when entering the context + """ + # reserve rng states + self.cpu_rng_state = torch.get_rng_state() + self.cuda_rng_state = torch.cuda.get_rng_state() + + def _post_context_exec(self): + """ + The callback function when exiting context. + """ + + # reset rng states + torch.set_rng_state(self.cpu_rng_state) + torch.cuda.set_rng_state(self.cuda_rng_state) + + def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): + """ + The function to call at the end of the constructor of each module. + NOTE() The module may be passed to this function multiple times. + """ + # iterate over the positional arguments + # to check if an argument is a torch Module + # if found any torch Module, replace it with its layer spec + # for storage purpose + modified_args = [] + for arg in args: + if isinstance(arg, torch.nn.Module): + # if nn.Module is an argument of a non-root module, then we should convert it to layer spec, which make sure the correct init method used in the real build. + # if nn.Module is an argument of the root module, then we should just record the module instance itself, because those instance has been built outside of the context. + if id(arg) in self._layer_spec_dict: + arg = self._layer_spec_dict[id(arg)] + + modified_args.append(arg) + + # to the same for the keyword arguments + modified_kwargs = {} + for k, v in kwargs.items(): + if isinstance(v, torch.nn.Module): + v = self._layer_spec_dict[id(v)] + # (lyl)TODO: analyse ColoTensor as well + modified_kwargs[k] = v + + # keep track of the module children + # as torch.nn.Module.__init__ is called from inner module to outer module, + # the final value of self._model will be the outermost model + # e.g. if the model is torchvision.models.resnet18, then the final value of self._model + # will be the ``ResNet`` object. + self._root_children = list(module.children()) + self._model = module + + # store the children to keep the module hierarchy + layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs) + layer_spec.set_children(module.children()) + + # store the layer spec in this context + module_id = id(module) + self._layer_spec_dict[module_id] = layer_spec + + # convert all torch.nn.Parameter to colossalai.tensor.ColoParameter + name_list = [] + for name, param in module.named_parameters(): + if isinstance(param, ColoParameter): + continue + name_list.append((name, param)) + + for name, param in name_list: + if hasattr(module, name): + delattr(module, name) + setattr(module, name, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad)) + + def to_layer_list(self, exec_seq=None): + """ + Create a layer spec list and func list with execution sequence given by user. + If exec_seq is None, we will take the module initizing order as execution order. + """ + + self._exec_seq = exec_seq + if exec_seq is None: + # if user do not provide the model executing sequence, we use the initialization order as the executing order. + children_name = [] + for child in self._root_children: + layer_spec = self._layer_spec_dict[id(child)] + if layer_spec.typename in (torch.nn.modules.container.ModuleList, + torch.nn.modules.container.Sequential): + for child_in_container in layer_spec.children: + self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)]) + for name, module in self._model.named_modules(): + if id(module) == id(child_in_container): + children_name.append(name) + break + else: + self._layer_spec_list.append(layer_spec) + for name, module in self._model.named_modules(): + if id(module) == id(child): + children_name.append(name) + break + + else: + front_funcs_list = [] + named_modules = dict(self._model.named_modules()) + for index, element in enumerate(exec_seq): + if isinstance(element, str): + if element == 'SPLIT_NODE': + continue + assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.' + + # get the layer spec based on the module ID + module = named_modules[element] + layer_spec = self._layer_spec_dict[id(module)] + + # check whether there are functions which should be executed before this module + if len(front_funcs_list) != 0: + func_key = (layer_spec, "front") + if func_key not in self._func_dict: + self._func_dict[func_key] = [] + for f in front_funcs_list: + self._func_dict[func_key].append(f) + front_funcs_list = [] + + func_key = (layer_spec, "behind") + self._layer_spec_list.append(layer_spec) + elif isinstance(element, tuple) and element[1] == "front": + front_funcs_list.append(element[0]) + else: + if func_key not in self._func_dict: + self._func_dict[func_key] = [] + if isinstance(element, tuple): + self._func_dict[func_key].append(element[0]) + else: + self._func_dict[func_key].append(element) + + def partition(self, num_chunks, pipeline_size, rank): + """ + Partitioned model will be built respect to partion policy. + The real module instance will be built in this method. + """ + if isinstance(self._policy, str): + if self._policy == "uniform": + parts = partition_uniform(len(self._layer_spec_list), pipeline_size, num_chunks)[rank] + elif self._policy == "balanced": + param_counts = [] + for layer_spec in self._layer_spec_list: + param_counts.append(layer_spec.count_params()) + parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank] + elif self._policy == "customized": + assert self._exec_seq is not None, f'An explicit exec_seq must be defined by user in customized policy mode.' + self.customized_parts = customized_partition(self._exec_seq) + assert len(self.customized_parts) == gpc.get_world_size( + ParallelMode.PIPELINE + ), f'World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partions is {len(self.customized_parts)}' + parts = self.customized_parts[rank] + else: + raise ValueError("A string partition policy should be one of ['uniform', 'balanced', 'customized'].") + elif isinstance(self._policy, dict): + parts = self._policy[rank] + else: + raise ValueError("A partition policy should be either a string or a dictionary.") + + layers_to_build = [] + for start, end in parts: + layers_to_build += self._layer_spec_list[start:end] + behind_func_dict_in_partition = {} + front_func_dict_in_partition = {} + module_list_in_partition = [] + for layer in layers_to_build: + module = layer.build() + module_list_in_partition.append(module) + if (layer, "front") in self._func_dict: + front_func_dict_in_partition[id(module)] = self._func_dict[(layer, "front")] + elif (layer, "behind") in self._func_dict: + behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")] + module_list_in_partition = torch.nn.ModuleList(module_list_in_partition) + pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition, + behind_func_dict_in_partition) + + return pipeline_model + + +class PipelinableModel(torch.nn.Module): + + def __init__(self, module_list, front_func_dict, behind_func_dict): + super().__init__() + self._module_list = module_list + self._front_func_dict = front_func_dict + self._behind_func_dict = behind_func_dict + + def forward(self, *input_tensor, **kwargs): + for module in self._module_list: + + if id(module) in self._front_func_dict: + input_tensor = exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs) + + if isinstance(module, CheckpointModule): + forward_func = module._forward + else: + forward_func = module.forward + module_kwargs = build_kwargs_for_module(forward_func, input_tensor, kwargs) + if input_tensor is None: + input_tensor = call_module(module, kwargs=module_kwargs) + elif isinstance(input_tensor, torch.Tensor): + input_tensor = call_module(module, args=(input_tensor,), kwargs=module_kwargs) + else: + input_tensor = call_module(module, args=input_tensor, kwargs=module_kwargs) + + if id(module) in self._behind_func_dict: + input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs) + + return input_tensor diff --git a/colossalai/pipeline/pipeline_process_group.py b/colossalai/pipeline/pipeline_process_group.py new file mode 100644 index 0000000000000000000000000000000000000000..c61d97ebabfa354ced39b952708188287d56cab3 --- /dev/null +++ b/colossalai/pipeline/pipeline_process_group.py @@ -0,0 +1,168 @@ +from typing import List, Dict, Tuple +import os +import threading + +from torch.distributed import rpc +import torch.distributed as dist + +from colossalai.tensor import ProcessGroup + + +class PipelineProcessGroup: + # TODO : flexible API for DP size and TP size + # In the future design mode, dp_degree and tp_degree should be removed + def __init__(self) -> None: + self.is_initialize = False + + def set_global_info(self, + rank: int, + world_size: int, + dp_degree: int = 1, + tp_degree: int = 1, + num_worker_threads: int = 1, + device: str = "cuda") -> None: + + device_mesh_size = dp_degree * tp_degree + assert world_size % device_mesh_size == 0, "world_size must be the multiple of dp_degree * tp_degree !!!" + self._num_worker_threads = num_worker_threads + + self._device_mesh_size = device_mesh_size + self._rank = rank + self._world_size = world_size + self._dp_degree = dp_degree + self._tp_degree = tp_degree + self.device = device + self._stage_num = world_size // device_mesh_size + self._pp_rank = rank // device_mesh_size + self._pp_ranks = [(rank % device_mesh_size) + i * device_mesh_size for i in range(self._stage_num)] + self._local_stage_ranks = [(rank // device_mesh_size * device_mesh_size) + i for i in range(device_mesh_size)] + + # pp_ranks + self._initialize_pp_process_group() + + # initialise tp dp process groups + self._initialize_tp_dp_process_group() + + # status + self._is_first_pp_rank = self._pp_rank == 0 + self._is_last_pp_rank = self._pp_rank == self._stage_num - 1 + + self.is_initialize = True + + # lock + self.initialise_lock = threading.Lock() + self.chimera_lock = threading.Lock() + + def _initialize_process_group(self): + stage_num = self.get_stage_num() + if stage_num == 1: + return + device = self.device + world_size = self.get_world_size() + rank = self.get_global_rank() + backend = 'nccl' if device == 'cuda' else 'gloo' + dist.init_process_group(backend, world_size=world_size, rank=rank, group_name='main_group') + + def _initialize_pp_process_group(self) -> None: + rank = self.get_global_rank() + world_size = self.get_world_size() + + # build rpc connection + options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=self._num_worker_threads) + + for pp_rank in self._pp_ranks: + options.set_device_map(f'work{pp_rank}', {rank: pp_rank}) + + rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options) + + def _initialize_tp_dp_process_group(self) -> None: + rank = self.get_global_rank() + local_stage_ranks = self.get_local_stage_global_ranks() + dp_degree = self.get_dp_degree() + tp_degree = self.get_tp_degree() + self._tp_dp_process_group = ProcessGroup(rank, local_stage_ranks, tp_degree, dp_degree) + + def get_global_rank(self): + return self._rank + + def get_world_size(self): + return self._world_size + + def get_dp_degree(self) -> int: + return self._dp_degree + + def get_tp_degree(self) -> int: + return self._tp_degree + + def get_local_device_mesh_size(self) -> int: + return self._device_mesh_size + + def get_device_mesh_num(self) -> int: + pass + + def get_stage_num(self) -> int: + return self._stage_num + + def is_first_stage(self) -> bool: + return self._is_first_pp_rank + + def is_last_stage(self) -> bool: + return self._is_last_pp_rank + + def check_pp_rank_valid(self, pp_rank: int) -> bool: + return -1 < pp_rank < self._stage_num + + def get_local_pp_rank(self) -> int: + return self._pp_rank + + def get_prev_pp_rank(self) -> int: + prev_pp_rank = self._pp_rank - 1 + if not self.check_pp_rank_valid(prev_pp_rank): + assert ValueError(f"current rank's pp_rank: {self._pp_rank} doesn't have a previous stage!") + return prev_pp_rank + + def get_next_pp_rank(self) -> int: + next_pp_rank = self._pp_rank + 1 + if not self.check_pp_rank_valid(next_pp_rank): + assert ValueError(f"current rank's pp_rank: {self._pp_rank} doesn't have a next stage!") + return next_pp_rank + + def get_local_stage_global_ranks(self) -> List[int]: + return self._local_stage_ranks + + def local_dp_rank(self) -> int: + return self._tp_dp_process_group.dp_local_rank() + + def local_tp_rank(self) -> int: + return self._tp_dp_process_group.tp_local_rank() + + def get_pp_global_ranks(self) -> int: + return self._pp_ranks + + def get_dp_global_ranks(self): + pass + + def get_tp_global_ranks(self): + pass + + def get_chimera_all_reduce_group(self, pp_rank: int): + with self.chimera_lock: + if not hasattr(self, 'chimera_groups'): + world_size = self.get_world_size() + stage_num = self.get_stage_num() + assert world_size % 2 == 0, 'world_size must be even in chimera!' + self.chimera_groups = {} + for rank in range(world_size // 2): + pair = [rank, world_size - 1 - rank] + group = dist.new_group(pair) + self.chimera_groups[pair[0]] = group + self.chimera_groups[pair[1]] = group + self.chimera_groups[pair[0] + stage_num] = group + self.chimera_groups[pair[1] + stage_num] = group + self.chimera_step_lock = threading.Lock() + self.chimera_step_lock.acquire() + + return self.chimera_groups[pp_rank] + + +ppg = PipelineProcessGroup() diff --git a/colossalai/pipeline/rpc/__init__.py b/colossalai/pipeline/rpc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d9e9d44f46c5aea6c0d211d24af21ed834a2503 --- /dev/null +++ b/colossalai/pipeline/rpc/__init__.py @@ -0,0 +1,4 @@ +from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine +from .utils import pytree_map + +__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map'] \ No newline at end of file diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2d7e25c82e7b917ce7f0b9e4eaf1d15cc8a3cbdd --- /dev/null +++ b/colossalai/pipeline/rpc/_pipeline_base.py @@ -0,0 +1,1309 @@ +import inspect +import math +import threading +from abc import ABC, abstractmethod +from enum import Enum +from functools import partial +from typing import Any, Callable, Dict, List, Tuple + +import torch +import torch.distributed.rpc as rpc +from torch import autograd, nn, optim +from torch._C._distributed_rpc import PyRRef +from torch.futures import Future + +from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc.utils import ( + get_batch_lengths, + pyobj_map, + pytree_filter, + pytree_map, + split_batch, + tensor_shape_list, + type_detail, +) + + +class Phase(Enum): + FORWARD = 0 + BACKWARD = 1 + UPDATE = 2 + INPUT = 3 + + +class UniqueKey: + __slots__ = ('microbatch_id', 'phase') + microbatch_id: int + phase: Phase + + def __init__(self, microbatch_id, phase) -> None: + self.microbatch_id = microbatch_id + self.phase = phase + + def __eq__(self, __o: object) -> bool: + return (self.microbatch_id == __o.microbatch_id) and (self.phase == __o.phase) + + def __hash__(self) -> int: + return tuple.__hash__((self.microbatch_id, self.phase)) + + def __repr__(self) -> str: + return f'Key(microbatch_id={self.microbatch_id}, phase={self.phase})' + + +class WorkItem: + __slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id', + 'num_microbatches', 'forward_only') + + stage_id: int + phase: Phase + args: Tuple[Any] + kwargs: Dict[str, Any] + output: Future + microbatch_id: int + refcount: int + batch_id: int + num_microbatches: int + forward_only: bool + + def __init__(self, + stage_id, + phase, + args, + kwargs, + output, + microbatch_id, + batch_id, + num_microbatches, + forward_only, + refcount=0) -> None: + for attr_name in self.__slots__: + setattr(self, attr_name, locals()[attr_name]) + + +class BackwardCache: + __slots__ = ('checkpoint', 'stage_input_args', 'stage_input_kwargs', 'stage_outputs') + checkpoint: bool + stage_input_args: Tuple[Any] + stage_input_kwargs: Dict[Any, Any] + stage_outputs: Tuple[Any] + + def __init__(self, + stage_input_args: Tuple[Any], + stage_input_kwargs: Dict[Any, Any] = None, + stage_outputs: Tuple[Any] = None, + checkpoint: bool = False) -> None: + for arg_name in self.__slots__: + setattr(self, arg_name, locals()[arg_name]) + + +class WorkerBase(ABC): + + def __init__(self, + partition_fn: Callable, + partition_args: tuple, + pp_rank: int, + actual_stage_num: int, + num_microbatches: int, + device: str, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None) -> None: + super().__init__() + + self.pp_rank = pp_rank + self.actual_stage_num = actual_stage_num + self.num_microbatches = num_microbatches + self.checkpoint = checkpoint + + if data_process_func is not None: + self.data_process_func = partial(data_process_func, pp_rank) + + self.device = device + self._initialize_outstanding_range() + + # variable and const for context managment + self.outstanding = 0 + self.forward_times = 0 + self.backward_times = 0 + self.reset_key = UniqueKey(0, Phase.FORWARD) + + # rref of other workers + self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None + + # lock for the list + self._initialize_lock() + + # topology info + self.producer_stage_ids: List[int] = None + self.consumer_stage_ids: List[int] = None + + # module partitions + self.partition_fn = partition_fn + self.partition_args = partition_args + self.criterion = criterion + self.metric = metric + self.reset = False + + # context to maintain loop + self._initialize_context_container() + + # main loop + self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True) + self.main_loop_thread.start() + + def _get_future_by_device(self): + return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device]) + + def _initialize_outstanding_range(self): + outstanding_range = None + if self.pp_rank == self.actual_stage_num - 1: + outstanding_range = (0, 1) + else: + outstanding_range = (self.actual_stage_num, self.actual_stage_num) + self.outstanding_range = outstanding_range + + def _initialize_context_container(self): + self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict() + self.microbatch_id_to_labels: Dict[int, Any] = dict() + self.work_list: Dict[UniqueKey, WorkItem] = dict() + self.output_list: Dict[UniqueKey, WorkItem] = dict() + + def _initialize_lock(self): + self.partition_condition_lock = threading.Condition(threading.Lock()) + self.work_list_condition_lock = threading.Condition(threading.Lock()) + self.output_list_condition_lock = threading.Condition(threading.Lock()) + self.label_lock = threading.Condition(threading.Lock()) + self.reset_condition = threading.Condition(threading.Lock()) + + def _initialize_partition(self): + partition_fn = self.partition_fn + partition_args = self.partition_args + device = self.device + with self.partition_condition_lock: + self.module_partition: nn.Module = partition_fn(*partition_args).to(device) + self.partition_condition_lock.notify_all() + + def _get_output_all(self, key: UniqueKey, ref_use=False, rank=None): + with self.output_list_condition_lock: + self.output_list_condition_lock.wait_for(lambda: key in self.output_list) + output_work_item = self.output_list[key] + output = output_work_item.output + if not ref_use and output_work_item.phase != Phase.INPUT: + self.output_list.pop(key) + + if not ref_use and output_work_item.phase != Phase.INPUT: + output_work_item.refcount += 1 + refcount = output_work_item.refcount + # lifecycle management for DAG scheduler + if output_work_item.phase == Phase.FORWARD: + lifecycle = len(self.get_consumer_stage_ids()) + if self.is_model_output(): # an extra reference for scheduler collecting results + lifecycle += 1 + elif output_work_item.phase == Phase.BACKWARD: + lifecycle = len(self.get_producer_stage_ids()) + if self.is_model_input() and self._is_last_step( + output_work_item): # an extra reference for ensure_backward + lifecycle += 1 + else: + lifecycle = 0 + refcount = 0 + + with self.output_list_condition_lock: + if refcount <= lifecycle: + self.output_list[key] = output_work_item + self.output_list_condition_lock.notify_all() + + if isinstance(output, Future): + output = output.wait() + + return output + + def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None: + assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs" + assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None" + self.pp_rank_to_worker_rref = pp_rank_to_worker_rref + + # for some schedule need the other worker's info to initialise partition (like Chimera) + # construction of partition is executed after the registion of pp_rank_to_worker_rref + self._initialize_partition() + + # res_use works for lifecycle counter, + # if ref_use is True, lifecycle won't add. + # offset supports get partial output to reduce comm costs. + def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any: + output = self._get_output_all(key, ref_use, rank) + if offsets is None: # get all for non iterable output + return output + else: # get part for iterable output + output = [output[i] for i in offsets] + return output + + def get_numels(self) -> int: + numel = sum(param.numel() for param in self.module_partition.parameters()) + return numel + + def get_parameters(self) -> List[torch.Tensor]: + return [p for p in self.module_partition.parameters()] + + def get_parameter_gradients(self) -> List[torch.Tensor]: + return [p.grad for p in self.module_partition.parameters()] + + def get_partition(self): + with self.partition_condition_lock: + self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + return self.module_partition + + def get_partition_state_dict(self): + with self.partition_condition_lock: + self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + return self.module_partition.state_dict() + + def _make_args_kwargs(self, microbatch, merge=False): + if isinstance(microbatch, dict): + if merge: + return list(microbatch.values()), {} + return [], microbatch + elif isinstance(microbatch, torch.Tensor): + return [microbatch], {} + elif isinstance(microbatch, (tuple, list)): + args = [] + kwargs = {} + for arg in microbatch: + if isinstance(arg, dict): + kwargs.update(arg) + else: + args.append(arg) + if merge: + arg_lst = args + for arg in kwargs.values(): + arg_lst.append(arg) + return arg_lst, {} + return args, kwargs + else: + raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}") + + # just for first pp_rank + def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool): + key = UniqueKey(microbatch_id, Phase.FORWARD) + output = self._get_future_by_device() + + if not self.use_middleware(): + # make args and kwargs + args, kwargs = self._make_args_kwargs(microbatch) + + work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None, + self.num_microbatches, forward_only) + with self.work_list_condition_lock: + self.work_list[key] = work_item + self.work_list_condition_lock.notify_all() + else: + # make args and kwargs + arg_lst, _ = self._make_args_kwargs(microbatch, merge=True) + + # first stage assign correct input into other stages + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + input_partition = topo.get_input_partition() + self_input_offsets = input_partition.get_output_offsets(self_partition_id) + recv_input_key = UniqueKey(microbatch_id, Phase.INPUT) + + # set input for self rank + self_arg_lst = [] + for off in self_input_offsets: + self_arg_lst.append(arg_lst[off]) + + work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None, + self.num_microbatches, forward_only) + with self.work_list_condition_lock: + self.work_list[key] = work_item + self.work_list_condition_lock.notify_all() + + # put input tensor which other nodes need into output_list as Phase.INPUT + work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None, + self.num_microbatches, forward_only) + + with self.output_list_condition_lock: + self.output_list[recv_input_key] = work_item_remote + self.output_list_condition_lock.notify_all() + + # just for last pp_rank + def set_labels(self, microbatch_id: int, microlabels: Any): + with self.label_lock: + self.microbatch_id_to_labels[microbatch_id] = microlabels + self.label_lock.notify_all() + + # just for last pp_rank + def _begin_backward(self, microbatch_id: int): + with self.work_list_condition_lock: + assert self.producer_stage_ids is not None + + key = UniqueKey(microbatch_id, Phase.BACKWARD) + output = self._get_future_by_device() + grad_wrt_loss = None + + work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None, + self.num_microbatches, False) + + self.work_list[key] = work_item + self.work_list_condition_lock.notify_all() + + def _subscribe_producer(self, microbatch_id: int, forward_only: bool): + """ + You should call this function asynchronously + """ + stage_id = self.pp_rank + output = self._get_future_by_device() + if not self.use_middleware(): + producer_num = len(self.producer_stage_ids) + subscribe_forward_futures: List[Future] = [None] * producer_num + for i in range(producer_num): + producer_stage_id = self.producer_stage_ids[i] + producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD) + producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] + subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key) + else: + producer_stage_ids = self.get_producer_stage_ids() + producer_num = len(producer_stage_ids) + if self.need_model_input(): + producer_num += 1 # for input partition + subscribe_forward_futures: List[Future] = [None] * producer_num + + # TODO(jiangziyue) get single value instead of the whole output + if self.need_model_input(): + producer_stage_id = 0 + producer_output_key = UniqueKey(microbatch_id, Phase.INPUT) + producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] + offsets = self._get_input_offsets_by_index(target_index=0) + subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, + rank=self.pp_rank, + offsets=offsets) + + for i in range(0, producer_num - 1): + producer_stage_id = producer_stage_ids[i] + producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD) + producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] + target_index = i + 1 + offsets = self._get_input_offsets_by_index(target_index=target_index) + if offsets is not None and len(offsets) == 0: # no need to do rpc + subscribe_forward_futures[target_index] = [] + else: + subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key( + producer_output_key, rank=self.pp_rank, offsets=offsets) + + else: + for i in range(producer_num): + producer_stage_id = producer_stage_ids[i] + producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD) + producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] + target_index = i + offsets = self._get_input_offsets_by_index(target_index=target_index) + if offsets is not None and len(offsets) == 0: # no need to do rpc + subscribe_forward_futures[target_index] = [] + else: + subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key( + producer_output_key, rank=self.pp_rank, offsets=offsets) + + work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output, + microbatch_id, None, self.num_microbatches, forward_only) + + return work_item_from_producer + + # TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one. + def subscribe_producer(self, microbatch_id: int, forward_only: bool): + key = UniqueKey(microbatch_id, Phase.FORWARD) + with self.work_list_condition_lock: + if key not in self.work_list: + # On current PP middleware design for DAG, get_output_by_key used by _subscribe_producer + # can only be executed once for every producer-consumer stage pair, which is necessary + # to count the lifecycle of work_item. So, keeping the _subscribe_producer in the same + # lock of work_item queue operation gurantees the consistency of lifecycle counter. + work_item_from_producer = self._subscribe_producer(microbatch_id, forward_only) + self.work_list[key] = work_item_from_producer + self.work_list_condition_lock.notify_all() + + def _subscribe_consumer(self, microbatch_id: int): + """ + You should call this function asynchronously + """ + stage_id = self.pp_rank + output = self._get_future_by_device() + if not self.use_middleware(): + consumer_stage_ids = self.consumer_stage_ids + else: + consumer_stage_ids = self.get_consumer_stage_ids() + consumer_num = len(consumer_stage_ids) + subscribe_backward_futures: List[Future] = [None] * consumer_num + for i in range(consumer_num): + consumer_stage_id = consumer_stage_ids[i] + consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD) + consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id] + target_index = i + offsets = self._get_output_offsets_by_index(target_index=target_index) + if offsets is not None and len(offsets) == 0: # no need to do rpc + subscribe_backward_futures[target_index] = [] + else: + subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key( + consumer_output_key, rank=self.pp_rank, offsets=offsets) + + # flatten args + work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output, + microbatch_id, None, self.num_microbatches, False) + + return work_item_from_consumer + + def subscribe_consumer(self, microbatch_id: int): + key = UniqueKey(microbatch_id, Phase.BACKWARD) + with self.work_list_condition_lock: + if key not in self.work_list: + # On current PP middleware design for DAG, get_output_by_key used by subscribe_consumer + # can only be executed once for every producer-consumer stage pair, which is necessary + # to count the lifecycle of work_item. So, keeping the subscribe_consumer in the same + # lock of work_item queue operation gurantees the consistency of lifecycle counter. + work_item_from_consumer = self._subscribe_consumer(microbatch_id) + self.work_list[key] = work_item_from_consumer + self.work_list_condition_lock.notify_all() + + def get_producer_stage_ids(self): + producer_stage_ids = [] + rank = self.pp_rank + if not self.use_middleware(): + prev_rank = rank - 1 + if prev_rank >= 0: + producer_stage_ids.append(prev_rank) + else: + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + input_partition_ids = self_partition.get_input_partition_ids() + model_input_partition_id = topo.get_input_partition_id() + for partition_id in input_partition_ids: + # ignore input partition in current implementation. + # it will be specially tackled. + if partition_id != model_input_partition_id: + producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo)) + return producer_stage_ids + + def get_consumer_stage_ids(self): + consumer_stage_ids = [] + rank = self.pp_rank + if not self.use_middleware(): + next_rank = rank + 1 + if next_rank <= self.actual_stage_num - 1: + consumer_stage_ids.append(next_rank) + else: + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + output_partition_ids = self_partition.get_output_partition_ids() + model_output_partition_id = topo.get_output_partition_id() + for partition_id in output_partition_ids: + if model_output_partition_id != partition_id: + consumer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo)) + return consumer_stage_ids + + def _get_producer_consumer(self) -> None: + rank = self.pp_rank + assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed" + assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed" + + # should be aranged in order, the order of the input of current forward + self.producer_stage_ids = self.get_producer_stage_ids() + self.consumer_stage_ids = self.get_consumer_stage_ids() + + def pp_rank_to_partition_id(self, pp_rank: int, topo: Topo): + partition_ids = topo.get_mid_partition_ids() + return partition_ids[pp_rank] + + def partition_id_to_pp_rank(self, partition_id: int, topo: Topo): + partition_ids = topo.get_mid_partition_ids() + for i, id in enumerate(partition_ids): + if id == partition_id: + return i + + def get_topo(self): + with self.partition_condition_lock: + self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + if hasattr(self.module_partition, '_topo'): + return self.module_partition._topo + else: + return None + + def use_middleware(self): + topo = self.get_topo() + return topo is not None + + def _get_input_offsets_by_index(self, target_index): + res = [] + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + model_input_partition_id = topo.get_input_partition_id() + input_vals = self_partition.get_input_vals() + producer_stage_ids = self.get_producer_stage_ids() + if self.need_model_input(): + # 0 for data from input batch + # >= 1 for data from prev stages + base = 1 + else: + # data from prev stages + base = 0 + for val in input_vals: + val_pos = val.get() + src_partition_id = val_pos.partition_id + src_offset = val_pos.offset + src_index = base + src_partition = topo.get_partition_by_id(src_partition_id) + output_len = len(src_partition.get_output_vals()) + # data from not-input partition + if src_partition_id != model_input_partition_id: + src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo) + src_index = base + for i, stage_id in enumerate(producer_stage_ids): + if stage_id == src_stage_id: + src_index += i + break + else: # data from input partition + src_index = 0 + # when output_len = 1, not iterable + if target_index == src_index: + if output_len == 1: + res = None # offset = None to get all outputs + return res + else: + res.append(src_offset) + return res + + def _get_output_offsets_by_index(self, target_index): + res = [] + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + output_vals = self_partition.get_output_vals() + consumer_stage_ids = self.get_consumer_stage_ids() + for val_list in output_vals: + # An output may be passed to many down stages. + target = None + for val_pos in val_list.get(): + dst_partition_id = val_pos.partition_id + dst_offset = val_pos.offset + dst_partition = topo.get_partition_by_id(dst_partition_id) + input_len = len(dst_partition.get_input_vals()) + dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo) + for i, stage_id in enumerate(consumer_stage_ids): + if stage_id == dst_stage_id: + dst_index = i + break + if target_index == dst_index: + if input_len == 1: + res = None # offset = None to get all outputs + return res + else: + res.append(dst_offset) + return res + + # TODO(jiangziyue) get single value instead of the whole output + def _get_real_args_kwargs_fwd(self, args_or_kwargs): + if not self.use_middleware(): + args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) + if args_or_kwargs is not None: + if isinstance(args_or_kwargs, dict): + pass + else: + flatten_args = [] + pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) + args_or_kwargs = flatten_args + else: + args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) + if args_or_kwargs is not None: + if isinstance(args_or_kwargs, dict): + pass + else: + flatten_args = [] + if self.is_first_stage(): + pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) + else: # get by offset + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + model_input_partition_id = topo.get_input_partition_id() + input_vals = self_partition.get_input_vals() + producer_stage_ids = self.get_producer_stage_ids() + if self.need_model_input(): + # 0 for data from input batch + # >= 1 for data from prev stages + base = 1 + else: + # data from prev stages + base = 0 + for val in input_vals: + val_pos = val.get() + src_partition_id = val_pos.partition_id + src_offset = val_pos.offset + src_index = base + src_partition = topo.get_partition_by_id(src_partition_id) + output_len = len(src_partition.get_output_vals()) + # data from not-input partition + if src_partition_id != model_input_partition_id: + src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo) + src_index = base + for i, stage_id in enumerate(producer_stage_ids): + if stage_id == src_stage_id: + src_index += i + break + else: # data from input partition + src_index = 0 + # when output_len = 1, not iterable + if output_len == 1: + target = args_or_kwargs[src_index] + else: + offsets = self._get_input_offsets_by_index(src_index) + real_offset = offsets.index(src_offset) + target = args_or_kwargs[src_index][real_offset] + flatten_args.append(target) + args_or_kwargs = flatten_args + return args_or_kwargs + + # TODO(jiangziyue) get single value instead of the whole output + def _get_real_args_kwargs_bwd(self, args_or_kwargs): + if not self.use_middleware(): + args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) + if args_or_kwargs is not None: + if isinstance(args_or_kwargs, dict): + pass + else: + flatten_args = [] + pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) + args_or_kwargs = flatten_args + else: + for i, arg in enumerate(args_or_kwargs): + args_or_kwargs[i] = arg.wait() + if args_or_kwargs is not None: # get by offset + flatten_args = [] + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + output_vals = self_partition.get_output_vals() + consumer_stage_ids = self.get_consumer_stage_ids() + for val_list in output_vals: + # An output may be passed to many down stages. + target = None + for val_pos in val_list.get(): + dst_partition_id = val_pos.partition_id + dst_offset = val_pos.offset + dst_partition = topo.get_partition_by_id(dst_partition_id) + input_len = len(dst_partition.get_input_vals()) + dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo) + for i, stage_id in enumerate(consumer_stage_ids): + if stage_id == dst_stage_id: + dst_index = i + break + if input_len == 1: + part_grad = args_or_kwargs[dst_index] + else: + offsets = self._get_output_offsets_by_index(dst_index) + real_offsets = offsets.index(dst_offset) + part_grad = args_or_kwargs[dst_index][real_offsets] + + if target is None: + target = part_grad + elif part_grad is not None: + target += part_grad + else: + continue + flatten_args.append(target) + args_or_kwargs = flatten_args + return args_or_kwargs + + @abstractmethod + def _get_work_item_key(self) -> UniqueKey: + """ + this method control the order of the microbatch to consume + """ + + def is_first_stage(self): + return self.pp_rank == 0 + + def is_last_stage(self): + return self.pp_rank == self.actual_stage_num - 1 + + def need_model_input(self): + need_input = False + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + self_partition = topo.get_partition_by_id(self_partition_id) + partition_inputs = self_partition.get_input_partition_ids() + model_input_partition_id = topo.get_input_partition_id() + if model_input_partition_id in partition_inputs: + need_input = True + return not self.is_first_stage() and need_input + + def is_model_output(self): + return self.is_last_stage() + + def is_model_input(self): + return self.is_first_stage() + + def _default_data_process_func(self, args_kwargs): + if self.is_first_stage(): + args = args_kwargs[0] + kwargs = args_kwargs[1] + else: + args = args_kwargs + kwargs = {} + + return args, kwargs + + def _consume_work_item_by_phase(self, work_item: WorkItem): + phase = work_item.phase + args = work_item.args + kwargs = work_item.kwargs + microbatch_id = work_item.microbatch_id + forward_only = work_item.forward_only + data_process_func = getattr(self, 'data_process_func', self._default_data_process_func) + consume_result = None + + is_first_stage = self.is_first_stage() + is_last_stage = self.is_last_stage() + + if phase == Phase.FORWARD: + # remind its consumer to get data before forward + if not is_last_stage: + for stage_id in self.consumer_stage_ids: + consumer_worker_rref = self.pp_rank_to_worker_rref[stage_id] + consumer_worker_rref.remote().subscribe_producer(microbatch_id, forward_only) + + # sustain pipeline context + self.forward_times += 1 + if not forward_only: + self.outstanding += 1 + + # parse and integrate args and kwargs + if is_first_stage: + args = self._get_real_args_kwargs_fwd(args) + kwargs = self._get_real_args_kwargs_fwd(kwargs) + args_kwargs = (args, kwargs) + else: + args_kwargs = self._get_real_args_kwargs_fwd(args) + + args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(), + process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU + args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device, + process_types=torch.device) # change devices from last stage to current device + + args, kwargs = data_process_func(args_kwargs) + + stage_outputs = None + stage_input_args = args + stage_input_kwargs = kwargs + use_checkpoint = None + + if forward_only: + with torch.no_grad(): + consume_result = self.module_partition(*args, **kwargs) + + if is_last_stage and self.criterion: + with self.label_lock: + self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels) + labels = self.microbatch_id_to_labels.pop(microbatch_id) + loss: torch.Tensor = self.criterion(consume_result, labels) + if self.metric is not None: + metric_result = self.metric(consume_result, labels) + if isinstance(metric_result, torch.Tensor): + metric_result = metric_result.item() + else: + metric_result = None + consume_result = [loss.item(), metric_result] + + # last stage doesn't need to do checkpoint, for it will do backward instantly + stage_input_args = None + stage_input_kwargs = None + stage_outputs = consume_result + + elif self.checkpoint and not is_last_stage: + with torch.no_grad(): + consume_result = self.module_partition(*args, **kwargs) + + stage_outputs = consume_result + use_checkpoint = True + + else: + consume_result = self.module_partition(*args, **kwargs) + + if is_last_stage and self.criterion: + with self.label_lock: + self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels) + labels = self.microbatch_id_to_labels.pop(microbatch_id) + loss: torch.Tensor = self.criterion(consume_result, labels) + if self.metric is not None: + metric_result = self.metric(consume_result, labels) + if isinstance(metric_result, torch.Tensor): + metric_result = metric_result.item() + else: + metric_result = None + + consume_result = [loss.item(), metric_result] + else: + loss = consume_result + + stage_outputs = loss + use_checkpoint = False + + if not forward_only: + self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_input_args, + stage_input_kwargs, + stage_outputs, + checkpoint=use_checkpoint) + consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'), + process_types=torch.Tensor) # torch rpc doesn't support args or rets in + + # if not forward_only, do the backward + if not forward_only: + if is_last_stage: # if it is the last stage, trigger backward automatic + self._begin_backward(microbatch_id) + + elif phase == Phase.BACKWARD: + # remind its producer to get data before backward + if not is_first_stage: + for stage_id in self.producer_stage_ids: + producer_worker_rref = self.pp_rank_to_worker_rref[stage_id] + producer_worker_rref.remote().subscribe_consumer(microbatch_id) + self.backward_times += 1 + self.outstanding -= 1 + + assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache" + backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id) + + stage_outputs = backward_cache.stage_outputs + stage_input_args = backward_cache.stage_input_args + stage_input_kwargs = backward_cache.stage_input_kwargs + use_checkpoint = backward_cache.checkpoint + + if use_checkpoint: + stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)] + + # overlap recompute and future.wait + if not is_last_stage: + grad_tensors = self._get_real_args_kwargs_bwd(args) + else: + grad_tensors = None + + # take tensor only (for only tensor can do backward) + # TODO(jiangziyue) : All values which should do bp are torch.Tensor? + stage_outputs = pytree_filter(lambda x: True, stage_outputs, process_types=torch.Tensor) + grad_tensors = pytree_filter(lambda x: True, grad_tensors, process_types=torch.Tensor) + + # output all input's grad to producer, even it has no grad(output None) + # to make the offset aligned to the topo's record. + if grad_tensors is not None: + filtered_outputs = [] + filtered_grads = [] + for i, grad in enumerate(grad_tensors): + stage_output = stage_outputs[i] + if stage_output.requires_grad and grad is not None: + filtered_outputs.append(stage_output) + filtered_grads.append(grad) + + stage_outputs = filtered_outputs + grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device), + process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU + autograd.backward(stage_outputs, grad_tensors=grad_tensors) + + # collect grad of input tensor + consume_result = [] + if not is_first_stage: + # In current design, input mush be a flatten args. + for arg in stage_input_args: + if isinstance(arg, torch.Tensor): + consume_result.append(arg.grad) + else: + consume_result.append(None) + consume_result = pyobj_map( + consume_result, fn=lambda x: x.to('cpu'), + process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU + + else: + raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}") + + return consume_result + + def _get_store_len(self): + return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}' + + def _get_parameter_grad_sum(self): + grad_sum = 0 + for p in self.module_partition.parameters(): + if p.grad is not None: + grad_sum += p.grad.sum() + return grad_sum + + def _is_first_step(self, work_item: WorkItem) -> bool: + return work_item.phase == Phase.FORWARD and work_item.microbatch_id == 0 + + def _is_last_step(self, work_item: WorkItem) -> bool: + if work_item.forward_only: + last_phase = Phase.FORWARD + else: + last_phase = Phase.BACKWARD + is_last_phase = work_item.phase == last_phase + is_last_microbatch = work_item.microbatch_id == self.num_microbatches - 1 + return is_last_phase and is_last_microbatch + + def _hook_before_step(self): + pass + + # install the main loop to wait for next batch input + def _wait_for_reset(self): + with self.reset_condition: + self.reset_condition.wait_for(lambda: self.reset) + self.reset = False + + # do the main loop to consume ready_list + def _work_loop(self): + # for init + self._get_producer_consumer() + torch.cuda.set_device(ppg.get_local_pp_rank()) + + # main loop + while True: + work_item_key = self._get_work_item_key() + # move current work item to output_list to activate subscribe in advance + with self.work_list_condition_lock: + self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list) + work_item = self.work_list[work_item_key] + + with self.output_list_condition_lock: + # assert work_item_key not in self.output_list + self.output_list[work_item_key] = work_item + self.output_list_condition_lock.notify_all() + + consume_result = self._consume_work_item_by_phase(work_item) + + with self.work_list_condition_lock: + self.work_list.pop(work_item_key) + work_item.output.set_result(consume_result) + + # if is last step in one batch reset context and do step + if self._is_last_step(work_item): + self._wait_for_reset() + + # reset context and resume loop + def reset_context(self): + self.forward_times = 0 + self.backward_times = 0 + self.outstanding = 0 + self._initialize_outstanding_range() + with self.work_list_condition_lock: + self.work_list.clear() + + with self.output_list_condition_lock: + self.output_list.clear() + + with self.reset_condition: + self.reset = True + self.reset_condition.notify_all() + + def initialize_optimizer(self, optimizer_class: type, **kwargs): + self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs) + + def step(self): + self._hook_before_step() + self.optimizer.step() + self.optimizer.zero_grad() + + +class PipelineEngineBase(ABC, nn.Module): + + def __init__(self, + worker_type, + partition_fn: Callable, + stage_num, + num_microbatches, + device: str, + use_1F1B=False, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None) -> None: + super().__init__() + self.worker_type = worker_type + self.partition_fn: Callable = partition_fn + self.chunk = chunk + self.criterion = criterion + self.metric = metric + self.num_microbatches = num_microbatches + self.device = device + self.use_1F1B = use_1F1B + self.stage_num = stage_num + self.checkpoint = checkpoint + self.data_process_func = data_process_func + + self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict() + + self._check_argument() + self._create_pp_rank_to_rpc_worker_id() + self._create_pp_rank_to_module_partition_id() + self._init_worker() + + def _check_argument(self) -> None: + # make virtual stage num + self.virtual_stage_num = self.stage_num * self.chunk + assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!" + + # check data_process_func + data_process_func = self.data_process_func + if data_process_func is not None: + assert callable(data_process_func), "data_process_func must be a function" + assert '' not in data_process_func.__repr__(), "data_process_func must be a global function" + assert '' not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression" + sig = inspect.signature(data_process_func) + assert len( + sig.parameters + ) == 2, f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead" + + def _get_actual_stage_num(self) -> int: + return self.stage_num if self.chunk == 1 else self.virtual_stage_num + + def _create_pp_rank_to_rpc_worker_id(self) -> None: + """create a map from model partition to stage_id, which is useful when use_interleave is True. + e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then + pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part + of partitions will be moved to device 0 and the others to device 1 + """ + stage_num = self.stage_num + actual_stage_num = self._get_actual_stage_num() + self.pp_rank_to_rpc_worker_id = [0] * actual_stage_num + for pp_rank in range(actual_stage_num): + self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank % stage_num + + def _create_pp_rank_to_module_partition_id(self) -> None: + """By default(both fill drain and 1F1B), length of model partitions equal to + actual_stage_num, so allocate model partition to corresponding stage + """ + actual_stage_num = self._get_actual_stage_num() + self.pp_rank_to_module_partition_id = [0] * actual_stage_num + for pp_rank in range(actual_stage_num): + self.pp_rank_to_module_partition_id[pp_rank] = pp_rank + + def _init_worker(self) -> None: + actual_stage_num = self._get_actual_stage_num() + + worker_type = self.worker_type + checkpoint = self.checkpoint + num_microbatches = self.num_microbatches + device = self.device + criterion = self.criterion + metric = self.metric + partition_fn = self.partition_fn + chunk = self.chunk + data_process_func = self.data_process_func + + for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)): + partition_id = self.pp_rank_to_module_partition_id[pp_rank] + partition_args = (partition_id, chunk, actual_stage_num) + rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank] + if device[:4] == 'cuda': + device = f'cuda:{rpc_worker_id}' + self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id, + worker_type, + args=(partition_fn, partition_args, pp_rank, + actual_stage_num, num_microbatches, device, + criterion, metric, checkpoint, data_process_func)) + + # let each worker know global worker rref (include itself) + sync_futs = [] + for pp_rank in self.pp_rank_to_worker_rref: + fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async(timeout=0).sync_global_worker_rrefs( + self.pp_rank_to_worker_rref) + sync_futs.append(fut) + + for fut in sync_futs: + fut.wait() + + def remote_numels(self) -> Dict[int, int]: + numels = {} + actual_stage_num = self._get_actual_stage_num() + for stage_id in range(actual_stage_num): + worker_rref = self.pp_rank_to_worker_rref[stage_id] + numel = worker_rref.rpc_sync().get_numels() + numels[stage_id] = numel + return numels + + def remote_parameters(self) -> Dict[int, List[torch.Tensor]]: + parameters = {} + actual_stage_num = self._get_actual_stage_num() + for stage_id in range(actual_stage_num): + parameters[stage_id] = [] + worker_rref = self.pp_rank_to_worker_rref[stage_id] + for p in worker_rref.rpc_sync().get_parameters(): + parameters[stage_id].append(p) + return parameters + + def remote_grad(self) -> Dict[int, List[torch.Tensor]]: + grads = {} + actual_stage_num = self._get_actual_stage_num() + for stage_id in range(actual_stage_num): + grads[stage_id] = [] + worker_rref = self.pp_rank_to_worker_rref[stage_id] + for grad in worker_rref.rpc_sync().get_parameter_gradients(): + grads[stage_id].append(grad) + return grads + + def get_input_pp_ranks(self) -> List[int]: + return [0] + + def get_output_pp_ranks(self) -> List[int]: + return [self._get_actual_stage_num() - 1] + + def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], + output_pp_ranks: List[int], ret_future): + actual_stage_num = self._get_actual_stage_num() + use_1F1B = self.use_1F1B + if microbatch_id >= actual_stage_num: + if forward_only or not use_1F1B: + for pp_rank in output_pp_ranks: + ret_future[pp_rank][microbatch_id - actual_stage_num].wait() + else: + key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD) + futs = [] + for pp_rank in input_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + fut = worker_rref.rpc_async().get_output_by_key(key, ref_use=True, offsets=[]) + futs.append(fut) + + for fut in futs: + fut.wait() + + def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]: + num_microbatches = self.num_microbatches + return {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks} + + def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool): + for pp_rank in input_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + # TODO : add relationship between input_pp_ranks and parts of microbatch + worker_rref.remote().set_input(microbatch_id, microbatch, forward_only) + + def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels): + for pp_rank in output_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + # TODO : add relationship between output_pp_ranks and parts of microlabels + worker_rref.remote().set_labels(microbatch_id, microlabels) + + # TODO(jiangziyue) : get model output with single value, instead of merging into last stage. + def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]): + key = UniqueKey(microbatch_id, Phase.FORWARD) + for pp_rank in output_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + ret_future[pp_rank][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key) + + def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]): + if not forward_only: + backward_result = [] + for pp_rank in input_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD) + fut = worker_rref.rpc_async().get_output_by_key( + key, offsets=[]) # only ensure the res exists, no need for real data. + backward_result.append(fut) + + for fut in backward_result: + fut.wait() + + def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]): + forward_result = [] + for pp_rank in output_pp_ranks: + worker_forward_result = [None] * self.num_microbatches + for microbatch_id in range(self.num_microbatches): + ret = ret_future[pp_rank][microbatch_id].wait() + # TODO : more stable format + ret = [ret] if isinstance(ret, torch.Tensor) else ret + worker_forward_result[microbatch_id] = ret + + worker_forward_result = list(zip(*worker_forward_result)) + forward_result.extend(worker_forward_result) + + return forward_result + + def _reset_worker(self): + actual_stage_num = self._get_actual_stage_num() + reset_futs: List[Future] = [] + for pp_rank in range(actual_stage_num): + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + fut = worker_rref.rpc_async().reset_context() + reset_futs.append(fut) + + for fut in reset_futs: + fut.wait() + + def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False): + batch_lengths = get_batch_lengths(batch) + batch_length = batch_lengths[0] + + if labels is not None and not forward_only: + assert hasattr( + self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward" + + num_microbatches = self.num_microbatches + + assert batch_length >= num_microbatches, "num_microbatches is greater than the size of a batch, which is illegal" + microbatch_size = math.ceil(batch_length / num_microbatches) + device = self.device + + # If Chimera mode is used, then rank of down pipeline is excluded from 'input_pp_ranks' or 'output_pp_ranks' + input_pp_ranks = self.get_input_pp_ranks() + output_pp_ranks = self.get_output_pp_ranks() + + # a cache to collect data and control flow + ret_future = self._create_ret_future(output_pp_ranks) + + for microbatch_id in range(num_microbatches): + # control data input speed + # to prevent exceed of wait limitations + # self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future) + batch_start = microbatch_size * microbatch_id + batch_end = min(batch_start + microbatch_size, batch_length) + + # set input + microbatch = split_batch(batch, batch_start, batch_end, device) + self._set_input(input_pp_ranks, microbatch_id, microbatch, forward_only) + + # set labels + if labels is not None: + # microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)] + microlabels = split_batch(labels, batch_start, batch_end, device) + self._set_labels(output_pp_ranks, microbatch_id, microlabels) + + # get data asynchronously + self._subscribe_forward(microbatch_id, output_pp_ranks, ret_future) + + # wait for first rank to ensure all backwards are done + self._ensure_backward(forward_only, input_pp_ranks) + + # collect forward result + forward_result = self._collect_forward_result(output_pp_ranks, ret_future) + + if not forward_only and hasattr(self, 'optimizer_class'): + self.step() + + self._reset_worker() # reset worker attributes for next batch + return forward_result + + def initialize_optimizer(self, optimizer_class: type, **kwargs): + self.optimizer_class = optimizer_class + for pp_rank in self.pp_rank_to_worker_rref: + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + worker_rref.remote().initialize_optimizer(optimizer_class, **kwargs) + + def step(self): + actual_stage_num = self._get_actual_stage_num() + step_futs: List[Future] = [] + for pp_rank in range(actual_stage_num): + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + fut = worker_rref.rpc_async().step() + step_futs.append(fut) + + for fut in step_futs: + fut.wait() diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/pipeline/rpc/_pipeline_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..0d572231d37830a89cac07eecf066f7e556ea640 --- /dev/null +++ b/colossalai/pipeline/rpc/_pipeline_schedule.py @@ -0,0 +1,346 @@ +import threading +from typing import Callable, Dict, List + +import torch +import torch.distributed as dist +from torch._C._distributed_rpc import PyRRef +from torch.futures import Future + +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem + +# Implementation of different Pipeline schedule +# Worker defines the worker for each stage +# PipelineEngine is the class for use + + +class FillDrainWorker(WorkerBase): + + def _get_work_item_key(self) -> UniqueKey: + # execute backward first (if backward phase in work_list) + num_microbatches = self.num_microbatches + + if self.forward_times < num_microbatches: + target_phase = Phase.FORWARD + target_microbatch_id = self.forward_times + else: + target_phase = Phase.BACKWARD + target_microbatch_id = self.backward_times + + target_key = UniqueKey(target_microbatch_id, target_phase) + + return target_key + + +class FillDrainPipelineEngine(PipelineEngineBase): + + def __init__(self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None) -> None: + + if chunk > 1: + assert num_microbatches % stage_num == 0, \ + "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" + use_1F1B = False + + super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, + metric, checkpoint, data_process_func) + + +class OneFOneBWorker(WorkerBase): + + def _get_work_item_key(self) -> UniqueKey: + # execute backward first (if backward phase in work_list) + pp_rank = self.pp_rank + actual_stage_num = self.actual_stage_num + num_microbatches = self.num_microbatches + is_last_stage = pp_rank == actual_stage_num - 1 + + if self.outstanding <= self.outstanding_range[0]: + target_phase = Phase.FORWARD + target_microbatch_id = self.forward_times + elif self.outstanding >= self.outstanding_range[1]: + target_phase = Phase.BACKWARD + target_microbatch_id = self.backward_times + else: + raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]") + + target_key = UniqueKey(target_microbatch_id, target_phase) + + # change outstanding_range at: + # 1. forward times reach actual_stage_num, this is the end of continuous forward + # 2. forward times reach num_microbatches, this is the end of 1F1B mode + if not is_last_stage and \ + target_key.phase == Phase.FORWARD: + if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2: + # Why need num_microbatches > 2 ? Because there is no steady stage when num_microbatches <= 2 + outstanding_min = actual_stage_num - pp_rank - 1 + outstanding_max = actual_stage_num - pp_rank + self.outstanding_range = (outstanding_min, outstanding_max) + if target_key.microbatch_id == num_microbatches - 1: + self.outstanding_range = (0, 0) + + return target_key + + +class OneFOneBPipelineEngine(PipelineEngineBase): + + def __init__(self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None) -> None: + + if chunk > 1: + assert num_microbatches % stage_num == 0, \ + "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" + # assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk" + use_1F1B = True + + super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, + metric, checkpoint, data_process_func) + + +class ChimeraWorker(WorkerBase): + + def _get_producer_consumer(self) -> None: + rank = self.pp_rank + min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num + max_pp_rank = min_pp_rank + self.actual_stage_num - 1 + + assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed" + assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed" + + # should be aranged in order, the order of the input of current forward + self.producer_stage_ids = [] + self.consumer_stage_ids = [] + + # Just for demo + prev_rank = rank - 1 + next_rank = rank + 1 + if prev_rank >= min_pp_rank: + self.producer_stage_ids.append(prev_rank) + if next_rank <= max_pp_rank: + self.consumer_stage_ids.append(next_rank) + + def _get_work_item_key(self) -> UniqueKey: + pp_rank = self.pp_rank + stage_num = self.actual_stage_num + real_microbatch_num = self.num_microbatches // 2 + + forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num + forward_block_num = self.forward_times // forward_block_size + + if self.forward_times >= real_microbatch_num or \ + ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times): + target_phase = Phase.BACKWARD + target_microbatch_id = self.backward_times + else: # others + target_phase = Phase.FORWARD + target_microbatch_id = self.forward_times + + # In up pipeline, microbatch_id to consume is 0, 2, 4 (2n) + # In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1) + real_target_microbatch_id = target_microbatch_id * 2 + if pp_rank >= stage_num: + real_target_microbatch_id += 1 + target_key = UniqueKey(real_target_microbatch_id, target_phase) + + with self.work_list_condition_lock: + self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list) + return target_key + + def _initialize_partition(self): + # In order to ensure the down pipeline share the same parameter + # with the up pipeline, partition of down partition will be copied + # from corresponding up stage + pp_rank = self.pp_rank + stage_num = self.actual_stage_num + device = self.device + if pp_rank < stage_num: + super()._initialize_partition() + else: + # if it is down pipeline, create partition by origin method + co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num] + # get the coresponding model state dict and wait for its init + state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict() + super()._initialize_partition() + self.module_partition.load_state_dict(state_dict) + + # init group for chimera in ppg + ppg.get_chimera_all_reduce_group(pp_rank) + + # lock for step sync + self.step_sync_lock = threading.Lock() + self.step_sync_lock.acquire() + + self.have_grad_lock = threading.Lock() + self.have_grad_lock.acquire() + + def _get_lock_gradient(self): + self.have_grad_lock.acquire() + grads = self.get_parameter_gradients() + self.step_sync_lock.release() + return grads + + def is_first_stage(self): + return (self.pp_rank % self.actual_stage_num) == 0 + + def is_last_stage(self): + return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1 + + def _is_last_step(self, work_item: WorkItem) -> bool: + if work_item.forward_only: + last_phase = Phase.FORWARD + else: + last_phase = Phase.BACKWARD + is_last_phase = work_item.phase == last_phase + last_microbatch_id = self.num_microbatches - 1 + if self.pp_rank < self.actual_stage_num: + last_microbatch_id -= 1 + is_last_microbatch = work_item.microbatch_id == last_microbatch_id + return is_last_phase and is_last_microbatch + + def _get_step_order(self) -> List[int]: + # TODO : If you want to extend it to multi head chimera, overwrite here + stage_num = self.actual_stage_num + pp_rank = self.pp_rank + # pp_rank in the same device + local_device_pp_ranks = [pp_rank, stage_num * 2 - pp_rank - 1] + local_device_pp_ranks.sort(reverse=min(local_device_pp_ranks) < stage_num // 2) + return local_device_pp_ranks + + def _hook_before_step(self): + self.have_grad_lock.release() + pp_rank = self.pp_rank + stage_num = self.actual_stage_num + co_pp_rank = (pp_rank + stage_num) % (2 * stage_num) + + # if currrent pp_rank is not the first to do step + # wait its previous pp_rank finish step + grads = self.get_parameter_gradients() + + # send + co_worker = self.pp_rank_to_worker_rref[co_pp_rank] + co_grads = co_worker.rpc_sync()._get_lock_gradient() + # sync + self.step_sync_lock.acquire() + for i in range(len(grads)): + grads[i] += co_grads[i] + + +class ChimeraPipelineEngine(PipelineEngineBase): + + def __init__(self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None) -> None: + + assert num_microbatches % stage_num == 0, \ + "In Chimera, num_microbatches must be the multiply of stage_num!" + use_1F1B = False + chunk = 1 + + super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, + metric, checkpoint, data_process_func) + + def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], + output_pp_ranks: List[int], ret_future): + pass + + def _create_pp_rank_to_rpc_worker_id(self) -> None: + stage_num = self.stage_num + self.pp_rank_to_rpc_worker_id = [0] * (stage_num * 2) + for pp_rank in range(stage_num): + self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank + self.pp_rank_to_rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1 + + def _create_pp_rank_to_module_partition_id(self) -> None: + stage_num = self.stage_num + self.pp_rank_to_module_partition_id = [0] * (stage_num * 2) + for pp_rank in range(stage_num): + self.pp_rank_to_module_partition_id[pp_rank] = pp_rank + self.pp_rank_to_module_partition_id[pp_rank + stage_num] = pp_rank + + def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]: + num_microbatches = self.num_microbatches + stage_num = self.stage_num + up_ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks} + down_ret_future = {pp_rank + stage_num: [None] * num_microbatches for pp_rank in output_pp_ranks} + # merge up and down + return {**up_ret_future, **down_ret_future} + + def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool): + # offset is 0 for all the ranks in up pipeline + # offset is stage_num for all the ranks in down pipeline + offset = (microbatch_id % 2) * self.stage_num + for pp_rank in input_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] + worker_rref.remote().set_input(microbatch_id, microbatch, forward_only) + + def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels): + # offset is 0 for all the ranks in up pipeline + # offset is stage_num for all the ranks in down pipeline + offset = (microbatch_id % 2) * self.stage_num + for pp_rank in output_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] + worker_rref.remote().set_labels(microbatch_id, microlabels) + + def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]): + key = UniqueKey(microbatch_id, Phase.FORWARD) + offset = (microbatch_id % 2) * self.stage_num + for pp_rank in output_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] + ret_future[pp_rank + offset][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key) + + def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]): + stage_num = self.stage_num + num_microbatches = self.num_microbatches + if not forward_only: + for pp_rank in input_pp_ranks: + up_last_microbatch_id = num_microbatches - 2 + down_last_microbatch_id = num_microbatches - 1 + + up_worker_rref = self.pp_rank_to_worker_rref[pp_rank] + down_worker_rref = self.pp_rank_to_worker_rref[pp_rank + stage_num] + + up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD) + down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD) + up_worker_rref.rpc_sync().get_output_by_key(up_key) + down_worker_rref.rpc_sync().get_output_by_key(down_key) + + def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[PyRRef, List[Future]]): + """Logic of collection of forward in Chimera. + Currently, only one input one output model is supported + """ + stage_num = self.stage_num + forward_result = [] + for pp_rank in output_pp_ranks: + worker_forward_result = [None] * self.num_microbatches + for microbatch_id in range(self.num_microbatches): + offset = (microbatch_id % 2) * stage_num + ret = ret_future[pp_rank + offset][microbatch_id].wait() + ret = [ret] if isinstance(ret, torch.Tensor) else ret + worker_forward_result[microbatch_id] = ret + + worker_forward_result = list(zip(*worker_forward_result)) + forward_result.extend(worker_forward_result) + + return forward_result diff --git a/colossalai/pipeline/rpc/utils.py b/colossalai/pipeline/rpc/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..06e6d976d7715cd9d6d57f4e196bd0bd1b3116cd --- /dev/null +++ b/colossalai/pipeline/rpc/utils.py @@ -0,0 +1,155 @@ +import argparse +import os +import warnings +from typing import Any, Callable, Dict, List, Tuple, Type, Union + +import torch +import torch.distributed.rpc as rpc +import torch.multiprocessing as mp +from torch._C._distributed_rpc import _is_current_rpc_agent_set +from torch.futures import Future + +from colossalai.initialize import launch +from colossalai.pipeline.pipeline_process_group import ppg + + +def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any: + if isinstance(obj, process_types): + return fn(obj) + elif type(obj) is dict: + return {k: pyobj_map(obj[k], fn, process_types) for k in obj} + elif type(obj) is tuple: + return tuple(pyobj_map(o, fn, process_types) for o in obj) + elif type(obj) is list: + return list(pyobj_map(o, fn, process_types) for o in obj) + else: + return obj + + +def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: + """process object recursively, like pytree + + Args: + obj (:class:`Any`): object to process + fn (:class:`Callable`): a function to process subobject in obj + process_types (:class: `type | tuple[type]`): types to determine the type to process + map_all (:class: `bool`): if map_all is True, then any type of element will use fn + + Returns: + :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn` + """ + if isinstance(obj, dict): + return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj} + elif isinstance(obj, tuple): + return tuple(pytree_map(o, fn, process_types, map_all) for o in obj) + elif isinstance(obj, list): + return list(pytree_map(o, fn, process_types, map_all) for o in obj) + elif isinstance(obj, process_types): + return fn(obj) + else: + return fn(obj) if map_all else obj + + +def tensor_shape_list(obj): + return pytree_map(obj, fn=lambda x: x.shape, process_types=torch.Tensor) + + +def get_batch_lengths(batch): + lengths = [] + pytree_map(batch, fn=lambda x: lengths.append(len(x)), process_types=torch.Tensor) + return lengths + + +def split_batch(batch: Any, start, stop, device: str): + if device == 'cuda': + fn = lambda x: x[start:stop].cuda() + else: + fn = lambda x: x[start:stop] + return pytree_map(batch, fn=fn, process_types=torch.Tensor) + + +def type_detail(obj): + return pytree_map(obj, lambda x: type(x), map_all=True) + + +def pytree_filter(fn, obj, process_types): + if obj is None: + return None + + filters = [] + + def condition_append(obj): + if fn(obj): + filters.append(obj) + + pytree_map(obj, fn=condition_append, process_types=process_types) + return filters + + +def get_real_args_kwargs(args_or_kwargs): + args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) + # TODO : combine producer and consumer + # by default, merge all args in the output args or kwargs + if args_or_kwargs is not None: + if isinstance(args_or_kwargs, dict): + pass + else: + flatten_args = [] + pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) + args_or_kwargs = flatten_args + + return args_or_kwargs + + +def run_worker(rank, args, master_func): + os.environ['MASTER_ADDR'] = args.master_addr + os.environ['MASTER_PORT'] = args.master_port + + device = args.device + world_size = args.world_size + dp_degree = args.dp_degree + tp_degree = args.tp_degree + num_worker_threads = args.num_worker_threads + host = args.master_addr + port = args.master_port + backend = 'nccl' if device == 'cuda' else 'gloo' + + launch(dict(), rank, world_size, host, int(port), backend, verbose=False) + ppg.set_global_info(rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device) + ppg.args = args + # in rpc mode, only rank 0 is needed to be coded + if rank == 0: + master_func(args) + # barrier here + if _is_current_rpc_agent_set(): + rpc.shutdown() + else: + warnings.warn("RPC has not been initialized") + + +def rpc_run(args, master_func): + world_size = args.world_size + mp.spawn(run_worker, args=(args, master_func), nprocs=world_size) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--epoch', type=int, default=1) + parser.add_argument('--world_size', type=int, default=2) + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--dp_degree', type=int, default=1) + parser.add_argument('--tp_degree', type=int, default=1) + parser.add_argument('--num_microbatches', type=int, default=2) + parser.add_argument('--chunk', type=int, default=1) + parser.add_argument('--use_checkpoint', action='store_true') + parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD') + parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') + parser.add_argument('--master_addr', type=str, default='localhost') + parser.add_argument('--master_port', type=str, default='29020') + parser.add_argument('--num_worker_threads', type=int, default=128) + return parser.parse_args() diff --git a/colossalai/pipeline/utils.py b/colossalai/pipeline/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..df7226644a7a9931ecd82c09a39653fc5e0cfdeb --- /dev/null +++ b/colossalai/pipeline/utils.py @@ -0,0 +1,275 @@ +import heapq +import inspect +import torch + +from colossalai.logging import get_dist_logger +from colossalai.nn.layer.utils import CheckpointModule +from typing import List + +from collections import OrderedDict + +def _binary_partition(weights: List, start: int, end: int): + """Returns the binary partition position of `weights`, given the start + position `st` and the end position `ed`. + + Args: + weights (list): A python list to be binary partitioned + start (int): the start position of the binary partition + end (int): the end position of the binary partition + + Returns: + int: the binary partition position of `weights` + """ + w_sum = weights[end - 1] + prefix = 0 + if start > 0: + w_sum -= weights[start - 1] + prefix = weights[start - 1] + minimum = float("inf") + for idx in range(start + 1, end): + front = weights[idx - 1] - prefix + diff = abs(w_sum - 2 * front) + if diff < minimum: + pos = idx + minimum = diff + + return start, pos, end + + +def _heap_addition(weights: List, intervals: int, add_cnt: int): + """ + """ + + def _heap_push(heap, st, ed): + value = weights[ed - 1] + if st > 0: + value -= weights[st - 1] + heapq.heappush(heap, (-value, st, ed)) + + ret_intervals = [] + heap = [] + + for st, ed in intervals: + _heap_push(heap, st, ed) + + while add_cnt > 0: + _, st, ed = heapq.heappop(heap) + if ed - st == 1: + ret_intervals.append((st, ed)) + else: + l, m, r = _binary_partition(weights, st, ed) + _heap_push(heap, l, m) + _heap_push(heap, m, r) + add_cnt -= 1 + + while heap: + _, st, ed = heapq.heappop(heap) + ret_intervals.append((st, ed)) + + ret_intervals.sort() + return ret_intervals + + +def _calc_partitions(weights, value): + prev = 0 + prefix = 0 + num_block = 0 + intervals = [] + + for idx, w in enumerate(weights): + if weights[idx] - prefix > value: + intervals.append((prev, idx)) + prev = idx + prefix = weights[idx - 1] + num_block += 1 + + intervals.append((prev, len(weights))) + return num_block + 1, intervals + + +def _binary_search(weights, num): + length = len(weights) + prefix = [1 if w == 0 else w for w in weights] + for i in range(1, length): + prefix[i] += prefix[i - 1] + + lower_bound = max(weights) + upper_bound = prefix[length - 1] + + while upper_bound > lower_bound: + mid = (upper_bound + lower_bound) // 2 + number, _ = _calc_partitions(prefix, mid) + if number <= num: + upper_bound = mid + else: + lower_bound = mid + 1 + + num_block, intervals = _calc_partitions(prefix, upper_bound) + if num_block < num: + intervals = _heap_addition(prefix, intervals, num - num_block) + + return intervals + + +def partition_uniform(num_items, pipeline_parallel_size, num_chunks): + assert num_items % num_chunks == 0, \ + "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" + + logger = get_dist_logger() + parts = [[] for _ in range(pipeline_parallel_size)] + partition_items = num_items // num_chunks + for idx in range(num_chunks): + base_idx = idx * partition_items + chunk_size = partition_items // pipeline_parallel_size + left = pipeline_parallel_size - partition_items % pipeline_parallel_size + if chunk_size == 0: + logger.warning("Some nodes in Pipeline have no requests") + + for p in range(pipeline_parallel_size): + st = base_idx + base_idx += chunk_size + (p >= left) + parts[p].append((st, base_idx)) + + return parts + + +def partition_balanced(weights, pipeline_parallel_size, num_chunks): + num_total = pipeline_parallel_size * num_chunks + num_items = len(weights) + if num_items <= num_total: + return partition_uniform(num_items, pipeline_parallel_size, num_chunks) + + intervals = _binary_search(weights, num_total) + + current = 0 + parts = [[] for _ in range(pipeline_parallel_size)] + for inter in intervals: + parts[current].append(inter) + current = (current + 1) % pipeline_parallel_size + + return parts + + +def build_kwargs_for_module(function, input_tensor, kw_dict): + """ + Generally, the first argument of module.forward is an input tensor come from the previous layer. + Therefore, we just filter the kwargs from second element of the dictionary. + """ + sig = inspect.signature(function) + if input_tensor is None: + kwargs_offset = 0 + elif isinstance(input_tensor, torch.Tensor): + kwargs_offset = 1 + elif isinstance(input_tensor, (tuple, OrderedDict)): + #assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' + # Huggingface will take their own structures based on OrderedDict as the output + # between layers so we've to close this check. + kwargs_offset = len(input_tensor) + args_name_list = list(sig.parameters.keys()) + kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]} + if len(kw_dict) == 0: + return None + return kw_dict + + +def build_kwargs_for_function(function, kw_dict): + sig = inspect.signature(function) + kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters} + if len(kw_dict) == 0: + return None + return kw_dict + + +def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs): + """ + We suppose the callable object passed to to_layer_list method in two purpose: + a. use the callable object to modify input tensor, such as \ + lambda x: torch.flatten(x, 1) + b. use the callable object to modify kwargs value, such as \ + def foo(attention_mask=None): + if attention_mask is not None: + batch_size = input_ids.shape[0] + attention_mask = attention_mask.view(batch_size, -1) + return attention_mask + """ + + if kw_dict is not None: + rst = func(**kw_dict) + if isinstance(rst, tuple): + for i, k in enumerate(kw_dict.keys()): + kwargs[k] = rst[i] + else: + for k in kw_dict.keys(): + kwargs[k] = rst + return input_tensor + if isinstance(input_tensor, tuple): + assert len(input_tensor) > 0, f'input_tensor should not be empty, when kw_dict is None.' + sig = inspect.signature(func) + func_args_num = len(sig.parameters) + assert func_args_num <= len( + input_tensor), f'func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}.' + if func_args_num < len(input_tensor): + return func(*input_tensor[:func_args_num]) + else: + return func(*input_tensor) + assert isinstance(input_tensor, torch.Tensor), 'input_tensor should be a type of torch.Tensor or tuple.' + return func(input_tensor) + + +def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs): + + assert func_key in func_dict, f"{func_key} is not in the function_dict." + funcs_to_exec = func_dict[func_key] + if isinstance(funcs_to_exec, list): + for f in funcs_to_exec: + f_kwargs = build_kwargs_for_function(f, kwargs) + input_tensor = exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs) + else: + f_kwargs = build_kwargs_for_function(funcs_to_exec, kwargs) + input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs) + + return input_tensor + + +def call_module(module, args=None, kwargs=None): + if args is None: + args = () + if kwargs is None: + kwargs = {} + if isinstance(module, CheckpointModule): + forward_func = module._forward + else: + forward_func = module.forward + sig = inspect.signature(forward_func) + param_nums = len(sig.parameters) + feed_nums = len(args) + len(kwargs) + args_needed_nums = param_nums - len(kwargs) + args_needed = args[:args_needed_nums] + if isinstance(module, CheckpointModule): + convert_kwargs_to_args = [] + for v in kwargs.values(): + convert_kwargs_to_args.append(v) + return module(*args_needed, *convert_kwargs_to_args) + else: + return module(*args_needed, **kwargs) + + +def customized_partition(exec_seq): + ''' + This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an + annotation to note the partition point. + ''' + customized_parts = {} + start = 0 + stop = 0 + rank = 0 + for element in exec_seq: + if isinstance(element, str): + if element == 'SPLIT_NODE': + customized_parts[rank] = [(start, stop)] + start = stop + rank += 1 + else: + stop += 1 + customized_parts[rank] = [(start, stop)] + return customized_parts diff --git a/colossalai/registry/__init__.py b/colossalai/registry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b48d9e7d21d2eed6f9e9dc54e26191aa3632a52c --- /dev/null +++ b/colossalai/registry/__init__.py @@ -0,0 +1,19 @@ +import torch.distributed.optim as dist_optim +import torch.nn as nn +import torch.optim as optim + +from .registry import Registry + +LAYERS = Registry("layers", third_party_library=[nn]) +MODELS = Registry("models") +OPTIMIZERS = Registry("optimizers", third_party_library=[optim, dist_optim]) +DATASETS = Registry("datasets") +DIST_GROUP_INITIALIZER = Registry("dist_group_initializer") +GRADIENT_HANDLER = Registry("gradient_handler") +LOSSES = Registry("losses", third_party_library=[nn]) +HOOKS = Registry("hooks") +TRANSFORMS = Registry("transforms") +DATA_SAMPLERS = Registry("data_samplers") +LR_SCHEDULERS = Registry("lr_schedulers") +SCHEDULE = Registry("schedules") +OPHOOKS = Registry("ophooks") diff --git a/colossalai/registry/registry.py b/colossalai/registry/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..8a4173f7ab992079d180322245d25cbb9010b07c --- /dev/null +++ b/colossalai/registry/registry.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from types import ModuleType +from typing import List + + +class Registry: + """This is a registry class used to register classes and modules so that a universal + object builder can be enabled. + + Args: + name (str): The name of the registry . + third_party_library (list, optional): + List of third party libraries which are used in the initialization of the register module. + """ + + def __init__(self, name: str, third_party_library: List[ModuleType] = None): + self._name = name + self._registry = dict() + self._third_party_lib = third_party_library + + @property + def name(self): + return self._name + + def register_module(self, module_class): + """Registers a module represented in `module_class`. + + Args: + module_class (class): The module to be registered. + Returns: + class: The module to be registered, so as to use it normally if via importing. + Raises: + AssertionError: Raises an AssertionError if the module has already been registered before. + """ + module_name = module_class.__name__ + assert module_name not in self._registry, f"{module_name} not found in {self.name}" + self._registry[module_name] = module_class + + # return so as to use it normally if via importing + return module_class + + def get_module(self, module_name: str): + """Retrieves a module with name `module_name` and returns the module if it has + already been registered before. + + Args: + module_name (str): The name of the module to be retrieved. + Returns: + :class:`object`: The retrieved module or None. + Raises: + NameError: Raises a NameError if the module to be retrieved has neither been + registered directly nor as third party modules before. + """ + if module_name in self._registry: + return self._registry[module_name] + elif self._third_party_lib is not None: + for lib in self._third_party_lib: + if hasattr(lib, module_name): + return getattr(lib, module_name) + raise NameError(f'Module {module_name} not found in the registry {self.name}') + + def has(self, module_name: str): + """Searches for a module with name `module_name` and returns a boolean value indicating + whether the module has been registered directly or as third party modules before. + + Args: + module_name (str): The name of the module to be searched for. + Returns: + bool: A boolean value indicating whether the module has been registered directly or + as third party modules before. + """ + found_flag = module_name in self._registry + + if self._third_party_lib: + for lib in self._third_party_lib: + if hasattr(lib, module_name): + found_flag = True + break + + return found_flag diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2da64e6c33a0f410dd5b7a9e05ff5775cc0a6eb --- /dev/null +++ b/colossalai/tensor/__init__.py @@ -0,0 +1,18 @@ +from . import distspec +from .colo_parameter import ColoParameter +from .colo_tensor import ColoTensor +from .comm_spec import CollectiveCommPattern, CommSpec +from .compute_spec import ComputePattern, ComputeSpec +from .dist_spec_mgr import DistSpecManager +from .distspec import ReplicaSpec, ShardSpec +from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager +from .process_group import ProcessGroup +from .tensor_spec import ColoTensorSpec +from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor + +__all__ = [ + 'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter', + 'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', + 'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', + 'merge_same_dim_mesh_list' +] diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..b384579feb35ef643fb3206a165ad4e3e0c02a0a --- /dev/null +++ b/colossalai/tensor/colo_parameter.py @@ -0,0 +1,119 @@ +from typing import Optional + +import torch + +from colossalai.tensor.colo_tensor import ColoTensor +from colossalai.tensor.const import TensorType +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.tensor.tensor_spec import ColoTensorSpec + + +def filter_colo_parameters(*args, **kwargs): + param_list = [] + + def get_colo_parameters(element) -> None: + if isinstance(element, list) or isinstance(element, tuple): + for e in element: + get_colo_parameters(e) + elif isinstance(element, dict): + raise RuntimeError("Found Dict: ColoParameter can't deal with complicated arguments.") + elif isinstance(element, ColoParameter): + param_list.append(element) + return + + for a in args: + get_colo_parameters(a) + for v in kwargs.values(): + get_colo_parameters(v) + + return param_list + + +def replace_args(args, kwargs, new_args): + args = new_args[:len(args)] + for k, v in zip(kwargs.keys(), new_args[len(args):]): + kwargs[k] = v + return tuple(args), kwargs + + +class ColoParameter(ColoTensor, torch.nn.Parameter): + r"""A kind of ColoTensor to be considered as a module parameter. + + """ + + def __new__(cls, + data: Optional[torch.Tensor] = None, + requires_grad: bool = True, + spec: ColoTensorSpec = None) -> 'ColoParameter': + if data is None: + data = torch.empty(0) + return torch.Tensor._make_subclass(cls, data, requires_grad) + + def __init__(self, + data: Optional[torch.Tensor] = None, + requires_grad: bool = True, + spec: ColoTensorSpec = None) -> None: + ColoTensor.__init__(self, data, spec) + self._type = TensorType.MODEL + # a list contains modules sharing this ColoParameter with others. + self._shared_param_modules = [] + + @property + def shared_param_modules(self): + return self._shared_param_modules + + @staticmethod + def from_torch_tensor(tensor: torch.Tensor, + requires_grad: bool = True, + spec: ColoTensorSpec = None) -> 'ColoParameter': + tensor = tensor.as_subclass(ColoParameter) + tensor.__init__(tensor, requires_grad=requires_grad, spec=spec) + return tensor + + def __repr__(self): + return super(ColoParameter, self).__repr__() + + @classmethod + def __torch_function__(cls, func, types, args=..., kwargs=None): + if ColoParamOpHookManager.has_hook(): + if not func.__name__.startswith('__'): + if kwargs is None: + kwargs = {} + params = filter_colo_parameters(*args, **kwargs) + if len(params) > 0: + with torch._C.DisableTorchFunction(): + new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values()) + args, kwargs = replace_args(args, kwargs, new_args) + ret = super().__torch_function__(func, types, args, kwargs) + with torch._C.DisableTorchFunction(): + ret = ColoParamOpHookManager.post_op(params, ret) + return ret + return super().__torch_function__(func, types, args, kwargs) + + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + else: + with torch._C.DisableTorchFunction(): + data = self.data.clone() + tensor = ColoParameter(data, + self.requires_grad, + spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec)) + memo[id(self)] = tensor + return tensor + + def __reduce_ex__(self, proto): + # Adapted from torch._utils._rebuild_parameter + # def _rebuild_colo_parameter(data, requires_grad, backward_hooks): + # colo_param = ColoParameter(data, requires_grad) + # colo_param._backward_hooks = backward_hooks + # return colo_param + + # return ( + # _rebuild_colo_parameter, + # (self.data, self.requires_grad, OrderedDict()) + # ) + + # TODO(jzy) we don't support object reflection now. + # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`. + raise NotImplementedError diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..4d762076461d085cd51aefd1b42d9c77491fa032 --- /dev/null +++ b/colossalai/tensor/colo_tensor.py @@ -0,0 +1,343 @@ +import operator +from copy import copy +from functools import lru_cache, reduce +from typing import Callable, Optional, Set + +import torch + +from colossalai.tensor.dist_spec_mgr import DistSpecManager +from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec +from colossalai.tensor.process_group import ProcessGroup +from colossalai.tensor.tensor_spec import ColoTensorSpec + +from .const import TensorType +from .op_wrapper import _COLOSSAL_OPS + + +@lru_cache(None) +def _get_my_nowrap_functions() -> Set[Callable]: + Tensor = torch.Tensor + return { + Tensor._base.__get__, + Tensor.grad.__get__, + Tensor._grad.__get__, + Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor + } + + +def _convert_output(output, colo_spec: ColoTensorSpec): + if type(output) == torch.Tensor: + return ColoTensor.from_torch_tensor(output, colo_spec) + elif isinstance(output, (list, tuple)): + return type(output)(_convert_output(o, colo_spec) for o in output) + else: + return output + + +def _get_spec_from_args(args, kwargs) -> ColoTensorSpec: + for elem in args: + if isinstance(elem, ColoTensor): + pg = elem.get_process_group() + dp = elem.dist_spec + return ColoTensorSpec(pg, dp) + elif isinstance(elem, (list, tuple)): + spec = _get_spec_from_args(elem, {}) + if spec is not None: + return spec + for k, v in kwargs.items(): + if isinstance(v, ColoTensor): + pg = v.get_process_group() + dp = v.dist_spec + return ColoTensorSpec(pg, dp) + return None + + +class ColoTensor(torch.Tensor): + """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor. + + The Colotensor can be initialized with a PyTorch tensor in the following ways. + + >>> pg = ProcessGroup() + >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())) + >>> # The tensor passed in is a tensor after sharding but not a global tensor. + >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size), + >>> dims=[0], + >>> num_partitions=[world_size]) + >>> tensor_spec = ColoTensorSpec(pg, shard_spec) + >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec) + + Args: + data (torch.Tensor): a torch tensor used as the payload the colotensor. + spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()). + """ + torch_major = int(torch.__version__.split('.')[0]) + torch_minor = int(torch.__version__.split('.')[1]) + + def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor': + """ + The signature of the __new__ has to be consistent with the torch.Tensor. + + Args: + data (torch.Tensor): a torch tensor used as the payload the colotensor. + spec (TensorSpec, optional): the tensor spec of initialization. + + Returns: + ColoTensor: a ColoTensor wrappers the data. + """ + if data is None: + data = torch.empty(0) + return torch.Tensor._make_subclass(cls, data, data.requires_grad) + + def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None: + # If not set spec, use a DP process group and replicate dist spec + if spec is None: + self.has_initialized = False + self.dist_spec = ReplicaSpec() + self.compute_spec = None + self.process_group = ProcessGroup() + else: + self.has_initialized = True + self.dist_spec = spec.dist_attr + self.compute_spec = spec.compute_attr + if spec.pg is None: + self.process_group = ProcessGroup() + else: + self.process_group = spec.pg + + self._type = TensorType.NONMODEL + + def has_compute_spec(self) -> bool: + return self.compute_spec is not None + + def is_model_data(self) -> bool: + return self._type == TensorType.MODEL + + def get_process_group(self) -> 'ProcessGroup': + return self.process_group + + def set_process_group(self, pg: ProcessGroup): + """set_process_group + change the pg of the ColoTensor. Note that the valid use cases is limited. + It works for the target pg is DP and TP only and current dist spec of the Tensor is Replica. + + Args: + pg (ProcessGroup): target pg + + """ + assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid" + # if the new pg is the same as the old pg, just returns + if self.process_group == pg: + return + assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \ + "Can not set_process_group on a ColoTensor whose process_group is both tp > 1 and world group > 1" + assert self.dist_spec.placement.value == 'r', \ + "Can not set_process_group on a ColoTensor whose dist spec is not Replica" + + self.process_group = pg + + def get_tp_world_size(self) -> int: + return self.process_group.tp_world_size() + + def get_dp_world_size(self) -> int: + """get_dp_world_size + get the dp world size of the tensor. + + Returns: + int: dp world size + """ + return self.process_group.dp_world_size() + + def set_dist_spec(self, dist_spec: _DistSpec): + """set_dist_spec + set dist spec and change the payloads. + + Args: + dist_spec (_DistSpec): target dist spec. + """ + assert isinstance(dist_spec, _DistSpec) + assert self.process_group is not None + self._redistribute(dist_spec) + + def set_tensor_spec(self, dist_spec, compute_spec): + if dist_spec is not None: + assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}" + self.set_dist_spec(dist_spec) + if compute_spec is not None: + self.compute_spec = compute_spec + + def has_compute_pattern(self, compute_pattern): + return self.compute_spec.compute_pattern == compute_pattern + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if not all(issubclass(cls, t) for t in types): + return NotImplemented + global _COLOSSAL_OPS + if func in _COLOSSAL_OPS: + func = _COLOSSAL_OPS[func] + + if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12): + # in order to trigger pre-op hook in the forward of checkpoint module + # we have to capture the `backward` function + # and make sure that it does not in `torch._C.DisableTorchFunction()` context + if func is torch.Tensor.backward: + assert len(args) == 1 # only has 1 parameter + backward_tensor = torch.Tensor(args[0]) + tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()} + return backward_tensor.backward(**tensor_kwargs) + + with torch._C.DisableTorchFunction(): + ret = func(*args, **kwargs) + if func in _get_my_nowrap_functions(): + return ret + else: + colo_spec = _get_spec_from_args(args, kwargs) + return _convert_output(ret, colo_spec) + + def __repr__(self): + output_list = [super(ColoTensor, self).__repr__()] + output_list.append(str(self.process_group)) + output_list.append(str(self.dist_spec)) + if self.compute_spec is not None: + output_list.append(str(self.compute_spec)) + return "\n".join(output_list) + + def _redistribute(self, dist_spec: _DistSpec) -> None: + """_redistribute + Note the function will not handle the logic of backward propagation! + It is used during model tensor initializations as an internal function. + + Args: + dist_spec (_DistSpec): the target dist. spec. + """ + assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted" + with DistSpecManager.no_grad(): + self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group) + self.dist_spec = dist_spec + + def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor': + """redistribute + Redistribute the tensor among processes. The rule is like this: + + 1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the + DP process group not changed. + + 2. If the pg is not not None and not equal to the current process group. + First, convert the tensor as replicated among the TP process group. + Second, reset the process group to the new pg. + Third, convert the tensor (new replicated both among the tp process group) to the new dist_spec. + + Args: + dist_spec (_DistSpec): the new dist spec. + pg (Optional[ProcessGroup], optional): the new process group . Defaults to None. + + Returns: + ColoTensor: a redistributed colotensor + """ + if pg is not None and pg != self.get_process_group(): + # if the pg is not equal, convert the current tensor to replicated + handled = self.redistribute(ReplicaSpec()) + else: + handled = self + pg = self.process_group + + ret = DistSpecManager.handle_trans_spec(handled, handled.dist_spec, dist_spec, pg) + return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec)) + + def to_replicate_(self): + """to_replicate_ + + an inline member function, converting dist spec of the tensor to REPLICATE + """ + self._redistribute(dist_spec=ReplicaSpec()) + + def to_replicate(self) -> 'ColoTensor': + """to_replicate + + converting dist spec of the tensor to ReplicaSpec() + """ + return self.redistribute(ReplicaSpec()) + + @staticmethod + def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor': + """from_torch_tensor + + A static method builds a `ColoTensor` from a PyTorch Tensor. + + Args: + tensor (torch.Tensor): the pytorch tensor, which is a local tensor for this rank not a global tensor. + spec (Optional[ColoTensorSpec], optional): tensor spec. Defaults to None. + + Returns: + ColoTensor: a ColoTensor + """ + tensor = tensor.as_subclass(ColoTensor) + tensor.__init__(tensor, spec=spec) + return tensor + + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + else: + with torch._C.DisableTorchFunction(): + data = self.data.clone() + tensor = ColoTensor(data, spec=copy(ColoTensorSpec(self.process_group, self.dist_spec, self.compute_spec))) + memo[id(self)] = tensor + return tensor + + # override builtin functions which must use tensor in replicate placement # + + def size_local(self, *args) -> torch.Size: + with torch._C.DisableTorchFunction(): + return super().size(*args) + + def size_global(self, *args) -> torch.Size: + """size_global + + override the torch building size() + the shape passed in must be in a replicate placement. + + Returns: + torch.Size: the global tensor shape + """ + if self.is_replicate(): + return self.size_local(*args) + spec = self.dist_spec + dims = spec.dims + num_partitions = spec.num_partitions + # import inspect + # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()]) + size_list = list(self.size_local()) + for dim, num_partition in zip(dims, num_partitions): + size_list[dim] *= num_partition + if args == (): + return torch.Size(size_list) + else: + return size_list[args[0]] + + def numel_global(self): + """Returns the number of elements in the tensor when it's replicated. + """ + return reduce(operator.mul, self.size_global(), 1) + + # Some API for dist spec check + + def is_replicate(self): + return self.dist_spec.placement == DistPlacementPattern.REPLICATE \ + or (len(self.dist_spec.num_partitions) == 1 + and self.dist_spec.num_partitions[0] == 1) \ + or (self.process_group.tp_world_size() == 1) + + def is_shard_1dcol(self): + return self.dist_spec.placement == DistPlacementPattern.SHARD \ + and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1 + + def is_shard_1drow(self): + return self.dist_spec.placement == DistPlacementPattern.SHARD \ + and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0 + + def is_sharded(self): + return self.dist_spec.placement == DistPlacementPattern.SHARD diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..af38d2a502c25ef76a6fb99650176781e09bda83 --- /dev/null +++ b/colossalai/tensor/comm_spec.py @@ -0,0 +1,527 @@ +import operator +from enum import Enum +from functools import reduce + +import torch +import torch.distributed as dist +from torch.distributed import ReduceOp + +__all__ = [ + 'CollectiveCommPattern', + 'CommSpec', +] + + +def _all_gather(tensor, comm_spec): + ''' + Implement all gather operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + tensor_list = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) + for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) + ] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output + + +def _split(tensor, comm_spec): + ''' + Implement shard operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, _ in process_groups_list: + if dist.get_rank() in rank_list: + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + start = length * rank_list.index(dist.get_rank()) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output + + +def _all_to_all(tensor, comm_spec): + ''' + Implement all to all operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) + new_shape = torch.Size(new_shape) + output_tensor_list = [ + torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) + ] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + input_tensor_list = [ + torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) + ] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output + + +def _all_reduce(tensor, comm_spec, async_op=False): + ''' + Implement all reduce operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor + + +def _mix_gather(tensor, comm_spec): + ''' + Implement mix gather operation on device mesh based on information provided by comm_spec. + Mix gather is the all-gather operation on all devices in the device_mesh(FlattenDeviceMesh) of the comm_spec. It is + different from _all_gather because _mix_gather does all-gather in two dimensions of device mesh, while _all_gather + only does all-gather in one dimension. + Assume index of f and b target pairs are 'f' and 'b' + ShardingSpec => gather_dim, logical_process_axes + S0S1 => [b, f], (1, 0) + S1S0 => [b, f], (0, 1) + S01R => [f], (1, 1) + RS01 => [b], (1, 1) + Example: + mesh_shape = (2,4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} + S0S1: + leading_group_dim = 1 + process_group = "[0, 1, 2, 3, 4, 5, 6, 7]" + tensor_list = [(0,0),(0,1),(0,2),(0,3),(1,0),(1,1),(1,2),(1,3)] # [(slice_id_f, slice_id_b),...] + mesh_shape = (2,4) + cat_slice = [4,2] + tmp_tensor_list = [(...,shape[f],shape[b]*4,...),(...,shape[f],shape[b]*4,...)] + tmp_tensor_list[0] = torch.cat(((0,0),(0,1),(0,2),(0,3)), dim=b) + tmp_tensor_list[1] = torch.cat(((1,0),(1,1),(1,2),(1,3)), dim=b) + output = torch.cat((tmp_tensor_list[0],tmp_tensor_list[1]), dim=a) + S1S0: + leading_group_dim = 0 + process_group = "[0, 4, 1, 5, 2, 6, 3, 7]" + tensor_list = [(0,0),(0,1),(1,0),(1,1),(2,0),(2,1),(3,0),(3,1)] + mesh_shape = (2,4) + cat_slice = [2,4] + tmp_tensor_list = [(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...)] + tmp_tensor_list[0] = torch.cat(((0,0),(0,1)), dim=b) + tmp_tensor_list[1] = torch.cat(((1,0),(1,1)), dim=b) + tmp_tensor_list[2] = torch.cat(((2,0),(2,1)), dim=b) + tmp_tensor_list[3] = torch.cat(((3,0),(3,1)), dim=b) + S10R: + leading_group_dim = 0 + process_group = "[0, 4, 1, 5, 2, 6, 3, 7]" + tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)] + S01R: + leading_group_dim = 1 + process_group = "[0, 1, 2, 3, 4, 5, 6, 7]" + tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)] + ''' + total_slices = comm_spec.device_mesh.mesh_shape[0] + tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)] + leading_group_dim = comm_spec.logical_process_axes[0] + assert len(comm_spec.device_mesh.process_groups_dict) == 1 + _, process_group = comm_spec.device_mesh.process_groups_dict[0][0] + process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim] + + # Global all_gather + dist.all_gather(tensor_list, tensor, group=process_group) + + # This is very ugly. I'm figuring out more elegant methods + tensor_list_sorted = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices) + ] + for i in range(total_slices): + tensor_list_sorted[i] = tensor_list[process_number_list[i]] + tensor_list = tensor_list_sorted + + if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]: + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous() + else: + mesh_shape = comm_spec.device_meshes.mesh_shape + cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]] + tmp_tensor_shape = list(tensor.shape) + tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0] + tmp_tensor_shape = torch.Size(tmp_tensor_shape) + tmp_tensor_list = [ + torch.zeros(tmp_tensor_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(cat_slice[1]) + ] + for i in range(cat_slice[1]): + tmp_tensor_list[i] = torch.cat(tuple(tensor_list[i * cat_slice[0]:(i + 1) * cat_slice[0]]), + comm_spec.gather_dim[0]).contiguous() + output = torch.cat(tuple(tmp_tensor_list), comm_spec.gather_dim[1]).contiguous() + + return output + + +def _mix_split(tensor, comm_spec): + ''' + Implement mix split operation. Mix split is only called for the backward of mix gather (Use ctx to keep consistent) + Mix split shards the tensor on device mesh based on information provided by comm_spec. It is different from split + because _mix_split shards the tensor in two dimensions of device mesh, while _split only shards in one dimension. + Assume index of f and b target pairs are 'f' and 'b' + S0S1 => [b, f], (1, 0) + S1S0 => [b, f], (0, 1) + S01R => [f], (0, 0) + RS01 => [b], (0, 0) + Example: + mesh_shape = (2,4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} + ''' + mesh_shape = comm_spec.device_meshes.mesh_shape + dim = comm_spec.gather_dim + total_slices = comm_spec.device_mesh.mesh_shape[0] + + # Get global rank + rank = dist.get_rank() + + leading_group_dim = comm_spec.logical_process_axes[0] + process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim] + rank = process_number_list.index(rank) + + if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]: + length = tensor.shape[dim[0]] // total_slices + start = length * rank + output = torch.narrow(tensor, dim[0], start, length).contiguous() + else: + tensor_shape = [tensor.shape[dim[0]], tensor.shape[dim[1]]] + rank_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]] + length = [tensor_shape[0] // rank_slice[0], tensor_shape[1] // rank_slice[1]] + start = [(rank % rank_slice[0]) * length[0], (rank // rank_slice[0]) * length[1]] + tmp_output = torch.narrow(tensor, dim[0], start[0], length[0]).contiguous() + output = torch.narrow(tmp_output, dim[1], start[1], length[1]).contiguous() + + return output + + +class _ReduceGrad(torch.autograd.Function): + """ + A customized communication operation which forward is an identity operation, + backward is all_reduce operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _all_reduce(grad_output, ctx.comm_spec), None + + +class _ReduceInput(torch.autograd.Function): + """ + A customized communication operation which forward is all_reduce operation, + backward is an identity operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _all_reduce(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + return _all_reduce(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + A customized communication operation which forward is split operation, + backward is an all gather operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _split(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return _split(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return _all_gather(grad_output, ctx.comm_spec), None + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """ + A customized communication operation which forward is an all gather operation, + backward is split operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _all_gather(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return _all_gather(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.comm_spec), None + + +class _AllToAll(torch.autograd.Function): + """ + A customized communication operation which forward is an all to all operation, + backward is an all to all operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _all_to_all(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + output = _all_to_all(input_, comm_spec) + comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, + sharding_spec=comm_spec.sharding_spec, + gather_dim=comm_spec.shard_dim, + shard_dim=comm_spec.gather_dim, + logical_process_axis=comm_spec.logical_process_axis) + ctx.comm_spec = comm_spec_for_backward + return output + + @staticmethod + def backward(ctx, grad_outputs): + return _all_to_all(grad_outputs, ctx.comm_spec), None + + +class _MixGatherForwardMixSplitBackward(torch.autograd.Function): + + @staticmethod + def symbolic(graph, input_): + return _mix_gather(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return _mix_gather(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return _mix_split(grad_output, ctx.comm_spec), None + + +def reduce_grad(input_, comm_spec): + return _ReduceGrad.apply(input_, comm_spec) + + +def reduce_input(input_, comm_spec): + return _ReduceInput.apply(input_, comm_spec) + + +def split_forward_gather_backward(input_, comm_spec): + return _SplitForwardGatherBackward.apply(input_, comm_spec) + + +def gather_forward_split_backward(input_, comm_spec): + return _GatherForwardSplitBackward.apply(input_, comm_spec) + + +def all_to_all(input_, comm_spec): + return _AllToAll.apply(input_, comm_spec) + + +def mixgather_forward_split_backward(input_, comm_spec): + return _MixGatherForwardMixSplitBackward.apply(input_, comm_spec) + + +class CollectiveCommPattern(Enum): + GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd' + ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd' + SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd' + ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd' + IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd' + MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd" + + +class CommSpec: + ''' + Communication spec is used to record the communication action. It has two main functions: + 1. Compute the communication cost which will be used in auto parallel solver. + 2. Convert the communication spec to real action which will be used in runtime. + It contains comm_pattern to determine the + communication method, sharding_spec to determine the communication size, gather_dim and shard_dim + to determine the buffer shape, and logical_process_axis + + Argument: + comm_pattern(CollectiveCommPattern): describe the communication method used in this spec. + sharding_spec(ShardingSpec): This is sharding spec of the tensor which will join the communication action. + gather_dim(int, Optional): The gather_dim of the tensor will be gathered. + shard_dim(int, Optional): The shard_dim of the tensor will be sharded. + logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. + ''' + + def __init__(self, + comm_pattern, + sharding_spec, + gather_dim=None, + shard_dim=None, + logical_process_axis=None, + forward_only=False, + mix_gather=False): + self.comm_pattern = comm_pattern + self.sharding_spec = sharding_spec + self.gather_dim = gather_dim + self.shard_dim = shard_dim + self.logical_process_axis = logical_process_axis + self.forward_only = forward_only + if isinstance(self.logical_process_axis, list): + if not mix_gather: + self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + self.logical_process_axis = 0 + else: + self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes + self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + # Create a new member `logical_process_axes` to distinguish from original flatten + self.logical_process_axes = logical_process_axis + else: + self.device_mesh = self.sharding_spec.device_mesh + + def __repr__(self): + res_list = ["CommSpec:("] + if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: + res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ") + res_list.append(f"gather_dim:{self.gather_dim}, ") + res_list.append(f"shard_dim:{self.shard_dim}, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: + res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ") + res_list.append(f"gather_dim:{self.gather_dim}, ") + res_list.append(f"shard_dim:{self.shard_dim}, ") + res_list.append(f"logical_process_axis: {self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: + res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ") + res_list.append(f"gather_dim:{self.gather_dim}, ") + res_list.append(f"shard_dim:{self.shard_dim}, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: + res_list.append(f"comm_pattern:ALLREDUCE_FWD_IDENTITY_BWD, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: + res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: + res_list.append(f"comm_pattern:MIXGATHER_FWD_SPLIT_BWD, ") + res_list.append(f"gather_dim:{self.gather_dim}, ") + res_list.append(f"logical_process_asex:{self.logical_process_axes})") + + return ''.join(res_list) + + def get_comm_cost(self): + ''' + For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to + compute the communication cost. + For shard operation, it is an on-chip operation, so the communication cost is zero. + ''' + comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1) + cost_dict = {} + if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: + forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) + # give a tiny cost to shard + backward_communication_cost = 100 + + if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: + forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis) + # grad should have same shape as input tensor + # all to all operation has same logical process axis as forward. + backward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis) + + if self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: + forward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis) + backward_communication_cost = 0 + + if self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: + forward_communication_cost = 0 + backward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis) + + if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: + # give a tiny cost to shard + forward_communication_cost = 100 + backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) + + if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: + # no need for axis because all devices are used in mix_gather + forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size) + backward_communication_cost = 100 + + if self.forward_only: + cost_dict["forward"] = forward_communication_cost + cost_dict["backward"] = 0 + cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"] + else: + cost_dict["forward"] = forward_communication_cost + cost_dict["backward"] = backward_communication_cost + cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"] + + return cost_dict + + def covert_spec_to_action(self, tensor): + ''' + Convert CommSpec into runtime action, implement real collection communication to target tensor. + The collection communication action is directed by the CommSpec. + + Argument: + tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks. + ''' + if self.comm_pattern in pattern_to_func_dict: + tensor = pattern_to_func_dict[self.comm_pattern](tensor, self) + else: + tensor = tensor + return tensor + + +pattern_to_func_dict = { + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: gather_forward_split_backward, + CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: all_to_all, + CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward, + CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input, + CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad, + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: mixgather_forward_split_backward, +} diff --git a/colossalai/tensor/compute_spec.py b/colossalai/tensor/compute_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..12f8f36bc61318910edfa3a0e5ece9cd81b6aafe --- /dev/null +++ b/colossalai/tensor/compute_spec.py @@ -0,0 +1,29 @@ +from enum import Enum + + +class ComputePattern(Enum): + TP1D = 0 + TP2D = 1 + TP2P5D = 2 + TP3D = 3 + + +class ComputeSpec(object): + """ComputeSpec + The Specification for computation pattern + + Args: + compute_pattern (ComputePattern): an Enum instance for compute pattern. + """ + + def __init__(self, compute_pattern: ComputePattern) -> None: + assert isinstance(compute_pattern, ComputePattern) + self.compute_pattern = compute_pattern + # Make sure output tensors are replicate + self.output_replicate = True + + def __repr__(self): + return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})' + + def set_output_replicate(self, flag: bool = True): + self.output_replicate = flag diff --git a/colossalai/tensor/const.py b/colossalai/tensor/const.py new file mode 100644 index 0000000000000000000000000000000000000000..356e8ecc885a3fb24766683b106a91ca2fac44eb --- /dev/null +++ b/colossalai/tensor/const.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class TensorType(Enum): + MODEL = 0 + NONMODEL = 1 # mainly activations diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..765d8ec1b01a7857aa74e4ea4cb599869a308f1c --- /dev/null +++ b/colossalai/tensor/d_tensor/comm_spec.py @@ -0,0 +1,310 @@ +from enum import Enum +from typing import Dict + +import torch +import torch.distributed as dist +from torch.distributed import ReduceOp + +__all__ = [ + 'CollectiveCommPattern', + 'CommSpec', +] + + +class CollectiveCommPattern(Enum): + GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd' + ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd' + SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd' + ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd' + IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd' + MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd" + + +class CommSpec: + ''' + Communication spec is used to record the communication action. It converts the communication spec + to real action which will be used in runtime. It contains comm_pattern to determine the + communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim + to determine the buffer shape, and logical_process_axis + + Argument: + comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. + process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec. + gather_dim(int, Optional): The gather_dim of the tensor will be gathered. + shard_dim(int, Optional): The shard_dim of the tensor will be sharded. + logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. + ''' + + def __init__(self, + comm_pattern: CollectiveCommPattern, + process_groups_dict: Dict, + gather_dim: int = None, + shard_dim: int = None, + logical_process_axis: int = None): + self.comm_pattern = comm_pattern + self.gather_dim = gather_dim + self.shard_dim = shard_dim + self.logical_process_axis = logical_process_axis + self.process_groups_dict = process_groups_dict + + def __repr__(self): + res_list = ["CommSpec:("] + if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: + res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ") + res_list.append(f"gather_dim:{self.gather_dim}, ") + res_list.append(f"shard_dim:{self.gather_dim}, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: + res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ") + res_list.append(f"gather_dim:{self.gather_dim}, ") + res_list.append(f"shard_dim:{self.shard_dim}, ") + res_list.append(f"logical_process_axis: {self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: + res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ") + res_list.append(f"gather_dim:{self.gather_dim}, ") + res_list.append(f"shard_dim:{self.shard_dim}, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: + res_list.append(f"comm_pattern:ALLREDUCE_FWD_IDENTITY_BWD, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: + res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + + return ''.join(res_list) + + def covert_spec_to_action(self, tensor): + ''' + Convert CommSpec into runtime action, implement real collection communication to target tensor. + The collection communication action is directed by the CommSpec. + + Argument: + tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks. + ''' + if self.comm_pattern in pattern_to_func_dict: + tensor = pattern_to_func_dict[self.comm_pattern](tensor, self) + else: + tensor = tensor + return tensor + + +def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec): + ''' + Implement all gather operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + tensor_list = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) + ] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output + + +def _split(tensor: torch.Tensor, comm_spec: CommSpec): + ''' + Implement shard operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, _ in process_groups_list: + if dist.get_rank() in rank_list: + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + start = length * rank_list.index(dist.get_rank()) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output + + +def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec): + ''' + Implement all to all operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) + new_shape = torch.Size(new_shape) + output_tensor_list = [ + torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) + ] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + input_tensor_list = [ + torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) + ] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output + + +def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False): + ''' + Implement all reduce operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor + + +class _ReduceGrad(torch.autograd.Function): + """ + A customized communication operation which forward is an identity operation, + backward is all_reduce operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _all_reduce(grad_output, ctx.comm_spec), None + + +class _ReduceInput(torch.autograd.Function): + """ + A customized communication operation which forward is all_reduce operation, + backward is an identity operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _all_reduce(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + return _all_reduce(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + A customized communication operation which forward is split operation, + backward is an all gather operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _split(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return _split(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return _all_gather(grad_output, ctx.comm_spec), None + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """ + A customized communication operation which forward is an all gather operation, + backward is split operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _all_gather(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return _all_gather(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.comm_spec), None + + +class _AllToAll(torch.autograd.Function): + """ + A customized communication operation which forward is an all to all operation, + backward is an all to all operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _all_to_all(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + output = _all_to_all(input_, comm_spec) + comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, + process_groups_dict=comm_spec.process_groups_dict, + gather_dim=comm_spec.shard_dim, + shard_dim=comm_spec.gather_dim, + logical_process_axis=comm_spec.logical_process_axis) + ctx.comm_spec = comm_spec_for_backward + return output + + @staticmethod + def backward(ctx, grad_outputs): + return _all_to_all(grad_outputs, ctx.comm_spec), None + + +def reduce_grad(input_, comm_spec): + return _ReduceGrad.apply(input_, comm_spec) + + +def reduce_input(input_, comm_spec): + return _ReduceInput.apply(input_, comm_spec) + + +def split_forward_gather_backward(input_, comm_spec): + return _SplitForwardGatherBackward.apply(input_, comm_spec) + + +def gather_forward_split_backward(input_, comm_spec): + return _GatherForwardSplitBackward.apply(input_, comm_spec) + + +def all_to_all(input_, comm_spec): + return _AllToAll.apply(input_, comm_spec) + + +pattern_to_func_dict = { + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: gather_forward_split_backward, + CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: all_to_all, + CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward, + CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input, + CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad, +} diff --git a/colossalai/tensor/d_tensor/d_tensor.py b/colossalai/tensor/d_tensor/d_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..c1fe9d50a048397c3fee6297cd80fda61c745bc3 --- /dev/null +++ b/colossalai/tensor/d_tensor/d_tensor.py @@ -0,0 +1,142 @@ +from typing import Optional + +import torch +from torch.utils._pytree import tree_map + +from .layout import Layout +from .layout_converter import LayoutConverter, to_global +from .sharding_spec import ShardingSpec + +layout_converter = LayoutConverter() + + +class DTensor(torch.Tensor): + + def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout): + self.local_tensor = local_tensor + self.data_type = local_tensor.dtype + self.entire_shape = local_tensor.shape + self.dist_layout = dist_layout + self._apply_layout() + + @staticmethod + def __new__(cls, local_tensor, layout): + return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad) + + def __repr__(self): + return f"DTensor({self.to_global()}, {self.dist_layout})" + + def __str__(self): + return self.__repr__() + + def layout_convert(self, target_layout): + ''' + Convert the layout of the tensor from source_spec to target_spec. + ''' + self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout) + self.dist_layout = target_layout + + def _apply_layout(self): + ''' + Apply the layout to the local tensor during initializing process. + ''' + source_spec = construct_default_sharding_spec(self.local_tensor) + source_layout = Layout(device_mesh=self.dist_layout.device_mesh, + device_type=self.dist_layout.device_type, + sharding_spec=source_spec, + entire_shape=self.entire_shape) + self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + def filter_arg(arg): + if isinstance(arg, DTensor): + return arg.local_tensor + else: + return arg + + args = tree_map(filter_arg, args) + kwargs = tree_map(filter_arg, kwargs) + # if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors + # and op type. + + return func(*args, **kwargs) + + @property + def device_mesh(self): + ''' + Return the device mesh of the tensor. + ''' + return self.dist_layout.device_mesh + + @property + def sharding_spec(self): + ''' + Return the sharding specification of the tensor. + ''' + return self.dist_layout.sharding_spec + + def to(self, *args, **kwargs): + ''' + Move the tensor to a new device or convert the tensor to a new dtype. + ''' + self.local_tensor = self.local_tensor.to(*args, **kwargs) + self.data_type = self.local_tensor.dtype + self.dist_layout.device_type = self.local_tensor.device + # TODO: update the device mesh process groups or we should just cache + # both the cpu process groups and the cuda process groups? + return self + + def to_local(self): + ''' + Return the local tensor in this rank. + ''' + return self.local_tensor + + def to_global(self): + ''' + Recover the global tensor from the distributed tensor. + + Note: This function will all_gather the local tensor to the global tensor and it + will not change the layout of the DTensor. This function is mainly used for debugging or + check the correctness of the distributed tensor. + ''' + return to_global(self.local_tensor, self.dist_layout) + + +def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor: + ''' + Distribute the local tensor to the distributed tensor according to the dist_layout specified. + + Args: + local_tensor: tensor to be distributed. + dist_layout: the layout specification of the distributed tensor. + + Returns: + A 'DTensor' object. + ''' + return DTensor(local_tensor, dist_layout) + + +def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module: + ''' + This function converts all the parameters in the module to DTensor(DParam). + + Note: This function is subject to future change as the DParam has not been implemented yet. + ''' + for name, param in module.named_parameters(): + if param is not None and not isinstance(param, DTensor): + # TODO: we could convert the parameter to DParam here, + # the type of the parameter could be an optional argument. + setattr(module, name, torch.nn.Parameter(partition_fn(name, param.data))) + return module + + +def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec: + ''' + Construct the default sharding specification for the tensor. + ''' + return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={}) diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7ef74a99aed377a3956ffe53fd96417a4b7aef --- /dev/null +++ b/colossalai/tensor/d_tensor/layout.py @@ -0,0 +1,68 @@ +import operator +from dataclasses import dataclass +from functools import reduce + +import torch + +from colossalai.device.device_mesh import DeviceMesh + +from .misc import DuplicatedShardingDimensionError, LayoutException, ShardingNotDivisibleError +from .sharding_spec import ShardingSpec + + +class Layout: + """Layout of a tensor. + + Attributes: + device_mesh: the device mesh to store the tensor distributed. + device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'. + sharding_spec: the sharding specification to describe how the tensor is sharded. + entire_shape: the entire shape of the global tensor. + """ + + def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec, + entire_shape: torch.Size): + self.device_mesh = device_mesh + self.device_type = device_type + self.sharding_spec = sharding_spec + self.entire_shape = entire_shape + self._sanity_check() + + def __hash__(self) -> int: + return hash(f'{self.sharding_spec}') + + def get_sharded_shape_per_device(self): + sharded_shape = list(self.entire_shape) + for dim, shard_list in self.sharding_spec.dim_partition_dict.items(): + mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] + shard_partitions = reduce(operator.mul, mesh_list, 1) + assert sharded_shape[ + dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' + sharded_shape[dim] //= shard_partitions + return torch.Size(sharded_shape) + + def _sanity_check(self): + sharding_spec = self.sharding_spec + + # make sure all axes in logical device mesh only be used once + dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim())) + for dim, shard_list in sharding_spec.dim_partition_dict.items(): + for element in shard_list: + if element in dim_check_list: + dim_check_list.remove(element) + else: + raise DuplicatedShardingDimensionError( + f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") + + # make sure that the sharding for a dimension is divisible by the number of devices + for dim, shard_list in sharding_spec.dim_partition_dict.items(): + tensor_dim_size = self.entire_shape[dim] + num_devices = 1 + + for element in shard_list: + num_devices *= self.device_mesh.mesh_shape[element] + + if tensor_dim_size % num_devices != 0: + raise ShardingNotDivisibleError( + f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.' + ) diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..cf02aac309f40d3bd7b5037000094601f5b2f2e3 --- /dev/null +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -0,0 +1,556 @@ +import math +from copy import deepcopy +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import numpy as np +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.tensor.d_tensor.comm_spec import * +from colossalai.tensor.d_tensor.layout import Layout +from colossalai.tensor.d_tensor.misc import LayoutException +from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator + +from .sharding_spec import ShardingSpec +from .utils import get_comm_cost + +__all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_options'] + + +@dataclass +class LayoutConverterOptions: + """ + LayoutConverterOptions is a dataclass which specifies the preferences for layout converting. + """ + # TODO: layout converter option is not implemented yet + pass + + +def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor: + layout_converter = LayoutConverter() + global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {}) + global_layout = Layout(device_mesh=layout.device_mesh, + device_type=layout.device_type, + sharding_spec=global_sharding_spec, + entire_shape=layout.entire_shape) + with torch.no_grad(): + global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout) + return global_tensor + + +def set_layout_converting_options(options: LayoutConverterOptions): + """ + Configure the shape consistency manager via function call. + """ + manager = LayoutConverter() + manager.options = options + + +class LayoutConverter(metaclass=SingletonMeta): + + def __init__(self): + self._options = None + self._forward_only = False + self.cached_solution = {} + + @property + def options(self): + return self._options + + @options.setter + def options(self, options_: LayoutConverterOptions): + assert isinstance(options_, LayoutConverterOptions) + self._options = options_ + + @property + def forward_only(self): + return self._forward_only + + @forward_only.setter + def forward_only(self, value): + assert isinstance(value, bool) + self._forward_only = value + + def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, CommSpec]: + ''' + Get all valid layouts from source_layout with single all-gather operation. + For the all-gather operation, we just care about the S dimension. + + Argument: + source_layout: the layout to be transformed. + + Return: + valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with single all-gather operation. + + Example: + layout_converter = LayoutConverter() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = (4, 4, 4) + dim_partition_dict = {0: [0], 1: [1]} + + # [S0,S1,R] + sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) + layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec, + entire_shape=entire_shape) + + rst_dict = layout_converter.all_gather_transform_layouts(layout) + for layout, comm_spec in rst_dict.items(): + print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}') + + Output: + [R, S1, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:0, shard_dim:0, logical_process_axis:0) + [S0, R, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1) + ''' + valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD + source_spec = source_layout.sharding_spec + process_groups_dict = source_layout.device_mesh.process_groups_dict + for target_pair in source_spec.dim_partition_dict.items(): + shard_list = all_gather_simulator(target_pair) + index = target_pair[0] + new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) + + # We won't add empty list into dim_partition_dict + # The key will be popped if the related shard_list is empty + if shard_list: + new_dim_partition_dict[index] = shard_list + else: + new_dim_partition_dict.pop(index) + + # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec + gather_dim = index + logical_process_axis = target_pair[1][-1] + comm_spec = CommSpec( + comm_pattern, + process_groups_dict=process_groups_dict, + gather_dim=gather_dim, + # shard_dim will be used during backward + shard_dim=gather_dim, + logical_process_axis=logical_process_axis) + + # generate new sharding spec + try: + new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) + new_layout = Layout(device_mesh=source_layout.device_mesh, + sharding_spec=new_sharding_spec, + device_type=source_layout.device_type, + entire_shape=source_layout.entire_shape) + + valid_spec_dict[new_layout] = comm_spec + except LayoutException: + pass + return valid_spec_dict + + def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]: + ''' + Get all valid layouts from source_layout with single all-to-all operation. + For the all-to-all operation, we just care about the pairs containing S dimension. + + Argument: + source_layout(Layout): the layout to be transformed. + + Return: + valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with single all-to-all operation. + + Example: + layout_converter = LayoutConverter() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = (4, 4, 4) + dim_partition_dict = {0: [0], 1: [1]} + + # [S0,S1,R] + sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) + layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec, + entire_shape=entire_shape) + rst_dict = layout_converter.all_to_all_transform_layout(layout) + + for layout, comm_spec in rst_dict.items(): + print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}') + + Output: + [S01, R, R]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:0, logical_process_axis: 1) + [R, S1, S0]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:0, shard_dim:2, logical_process_axis: 0) + [S0, R, S1]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:2, logical_process_axis: 1) + ''' + valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD + process_groups_dict = source_layout.device_mesh.process_groups_dict + source_spec = source_layout.sharding_spec + tensor_dims = source_spec.dims + for f_index in range(tensor_dims - 1): + for b_index in range(f_index + 1, tensor_dims): + # skip (R, R) cases + if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict: + continue + else: + if f_index in source_spec.dim_partition_dict: + # skip (S01, R) -> (R, S01) is NOT allowed + if len(source_spec.dim_partition_dict[f_index]) >= 2: + continue + f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index])) + else: + f_target_pair = (f_index, []) + if b_index in source_spec.dim_partition_dict: + # skip (R, S01) -> (S01, R) is NOT allowed + if len(source_spec.dim_partition_dict[b_index]) >= 2: + continue + b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index])) + else: + b_target_pair = (b_index, []) + + # skip (S1, S0) -> S10 + if f_target_pair[1] and b_target_pair[1] and f_target_pair[1][0] >= b_target_pair[1][0]: + continue + f_shard_list, b_shard_list = all_to_all_simulator(f_target_pair, b_target_pair) + f_index = f_target_pair[0] + b_index = b_target_pair[0] + + # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec + if len(f_shard_list) < len(f_target_pair[1]): + gather_dim = f_index + shard_dim = b_index + logical_process_axis = f_target_pair[1][-1] + else: + gather_dim = b_index + shard_dim = f_index + logical_process_axis = b_target_pair[1][-1] + comm_spec = CommSpec(comm_pattern, + process_groups_dict, + gather_dim=gather_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis) + + new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) + + # We won't add empty list into dim_partition_dict + # The key will be popped if the related shard_list is empty + if f_shard_list: + new_dim_partition_dict[f_index] = f_shard_list + else: + new_dim_partition_dict.pop(f_index) + if b_shard_list: + new_dim_partition_dict[b_index] = b_shard_list + else: + new_dim_partition_dict.pop(b_index) + + # generate new sharding spec + try: + new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) + new_layout = Layout(device_mesh=source_layout.device_mesh, + sharding_spec=new_sharding_spec, + device_type=source_layout.device_type, + entire_shape=source_layout.entire_shape) + valid_spec_dict[new_layout] = comm_spec + except LayoutException: + pass + + return valid_spec_dict + + def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]: + ''' + Get all valid layouts from source_layout with single shard operation. + For the sharding operation, we just care about legal sharding dimensions. + + Argument: + source_layout(Layout): the layout to be transformed. + + Return: + valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with single shard operation. + + Example: + layout_converter = LayoutConverter() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = (4, 4, 4) + + dim_partition_dict = {0: [0]} + + # [S0,R,R] + sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) + layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec, + entire_shape=entire_shape) + rst_dict = layout_converter.shard_transform_layout(layout) + + for layout, comm_spec in rst_dict.items(): + print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}') + + Output: + [S01, R, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:0, shard_dim:0, logical_process_axis:1) + [S0, S1, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1) + [S0, R, S1]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:2, shard_dim:2, logical_process_axis:1) + ''' + valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD + source_spec = source_layout.sharding_spec + process_groups_dict = source_layout.device_mesh.process_groups_dict + + # legal sharding dims means the mesh_id is still available to use. + legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))] + for dim, shard_list in source_spec.dim_partition_dict.items(): + for element in shard_list: + legal_sharding_dims.remove(element) + + if len(legal_sharding_dims) == 0: + return valid_spec_dict + + tensor_dims = source_spec.dims + + for index in range(tensor_dims): + if index not in source_spec.dim_partition_dict: + shard_list_list = shard_simulator((index, []), legal_sharding_dims) + else: + shard_list_list = shard_simulator((index, source_spec.dim_partition_dict[index]), legal_sharding_dims) + if not shard_list_list: + continue + for shard_list in shard_list_list: + new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) + new_dim_partition_dict[index] = shard_list + + # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec + shard_dim = index + logical_process_axis = shard_list[-1] + comm_spec = CommSpec(comm_pattern, + process_groups_dict, + gather_dim=shard_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis) + + # generate new sharding spec + try: + new_sharding_spec = ShardingSpec(dim_size=source_spec.dims, + dim_partition_dict=new_dim_partition_dict) + new_layout = Layout(device_mesh=source_layout.device_mesh, + sharding_spec=new_sharding_spec, + device_type=source_layout.device_type, + entire_shape=source_layout.entire_shape) + valid_spec_dict[new_layout] = comm_spec + except LayoutException: + pass + return valid_spec_dict + + def get_all_one_step_transform_spec(self, source_layout: Layout) -> Dict[Layout, CommSpec]: + ''' + Get all valid layouts from source_layout with one step transform. + + Note: + all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before, + and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive, + we could safely put them together. + + Argument: + source_layout(Layout): the layout to be transformer. + + Return: + valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with one step transform. + ''' + valid_spec_dict = {} + valid_spec_dict.update(self.all_gather_transform_layouts(source_layout)) + valid_spec_dict.update(self.all_to_all_transform_layout(source_layout)) + valid_spec_dict.update(self.shard_transform_layout(source_layout)) + return valid_spec_dict + + def layout_converting(self, source_layout: Layout, + target_layout: Layout) -> Tuple[List[Layout], List[CommSpec], float]: + ''' + This method will find a path to transform source_layout to target_layout with + a greedy algorithm. + The basic idea is: + Step1: + Generate all one-step transform sequences from source_layout. + Step2: + Pick the 'best' layout following the heuristic function. + Step3: + Repeat above steps until the source layout transform to target layout. + + Additionally, to avoid repeating the path search in runtime, we cached all solved path + in auto parallel strategy building time, which could handle most of cases in runtime. + + Args: + source_layout(Layout): the layout to be transformed. + target_layout(Layout): the layout to be achieved after a serious of transforms. + + Return: + transform_path(List[Layout]): The transform path from source_layout to target_layout, + it contains the source_layout and target_layout. + comm_action_sequence(List[CommSpec]): Keep the communication operations to complete the layout converting in order. + + Example: + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = (4, 4, 4) + + dim_partition_source = {1: [0, 1]} + dim_partition_target = {0: [0, 1]} + + # [R,S01,R] + sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) + source_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_source, + entire_shape=entire_shape) + + # [S01,R,R] + sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) + target_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_target, + entire_shape=entire_shape) + + transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) + transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) + print(transform_path_str) + + output: + [R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R] + ''' + source_spec = source_layout.sharding_spec + target_spec = target_layout.sharding_spec + MAX_TRANSFORM_STEPS = 20 + total_steps = 0 + transform_path = [] + comm_action_sequence = [] + spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence)) + + if spec_pairs in self.cached_solution: + return self.cached_solution[spec_pairs] + + # We do nothing if the sharding spec is all the same. + if source_spec.spec_diff(target_spec) == 0: + self.cached_solution[spec_pairs] = (transform_path, comm_action_sequence) + return ( + transform_path, + comm_action_sequence, + ) + + temp_sharding_layout = source_layout + + transform_path.append(temp_sharding_layout) + # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms + while total_steps <= MAX_TRANSFORM_STEPS: + valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_layout) + best_difference_score = math.inf + + for layout, comm_spec in valid_transform_spec_dict.items(): + sharding_spec = layout.sharding_spec + spec_difference = sharding_spec.spec_diff(target_spec) + + if spec_difference == 0: + transform_path.append(layout) + comm_action_sequence.append(comm_spec) + self.cached_solution[spec_pairs] = (transform_path, comm_action_sequence) + return (transform_path, comm_action_sequence) + + if spec_difference < best_difference_score: + temp_sharding_layout = layout + temp_comm_spec = comm_spec + best_difference_score = spec_difference + + transform_path.append(temp_sharding_layout) + comm_action_sequence.append(temp_comm_spec) + + total_steps += 1 + + raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") + + def get_total_comm_cost(self, source_layout: Layout, target_layout: Layout) -> Dict[str, float]: + ''' + Get the total communication cost of the layout converting process. + ''' + transform_path, comm_action_sequence = self.layout_converting(source_layout, target_layout) + total_cost = {'forward': 0.0, 'backward': 0.0, 'total': 0.0} + for layout, comm_spec in zip(transform_path, comm_action_sequence): + cost_dict = get_comm_cost(layout, comm_spec, self.forward_only) + for key in total_cost: + total_cost[key] += cost_dict[key] + return total_cost + + def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layout) -> torch.Tensor: + ''' + Apply target_layout to tensor with source layout, the transform path is generated by the + layout_converting method. + + Argument: + tensor (torch.Tensor): The tensor to be redistributed. + source_layout(Layout): The source layout of the tensor. + target_layout (Layout): The tensor will be redistributed to the target_layout. + + Example: + layout_converter = LayoutConverter() + dim_partition_source = {0: [0]} + dim_partition_target = {1: [0]} + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = (4, 4, 4) + + # [S0,R,R] + sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) + source_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_source, + entire_shape=entire_shape) + + # [R,S0,R] + sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) + target_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_target, + entire_shape=entire_shape) + + if rank in (0, 1): + sharded_tensor_0 = torch.zeros(2, 1) + sharded_tensor_1 = torch.ones(2, 1) + # tensor([[0., 1.], + # [0., 1.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + if rank in (2, 3): + sharded_tensor_0 = torch.ones(2, 1) * 2 + sharded_tensor_1 = torch.ones(2, 1) * 3 + # tensor([[2., 3.], + # [2., 3.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + + # converted_tensor: [R, S0, R] + converted_tensor = layout_converter.apply(tensor_to_comm, source_layout, target_layout) + print(converted_tensor) + + Output in rank0 and rank1: + tensor([[0.], + [0.], + [2.], + [2.]]) + + Output in rank2 and rank3: + tensor([[1.], + [1.], + [3.], + [3.]]) + ''' + _, comm_action_sequence = self.layout_converting(source_layout, target_layout) + for comm_spec in comm_action_sequence: + tensor = comm_spec.covert_spec_to_action(tensor) + return tensor diff --git a/colossalai/tensor/d_tensor/misc.py b/colossalai/tensor/d_tensor/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb3f6f1961eb9b7261d14abf998e693ce5ec2d3 --- /dev/null +++ b/colossalai/tensor/d_tensor/misc.py @@ -0,0 +1,14 @@ +class LayoutException(Exception): + pass + + +class DuplicatedShardingDimensionError(LayoutException): + pass + + +class ShardingNotDivisibleError(LayoutException): + pass + + +class ShardingOutOfIndexError(LayoutException): + pass diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..2ea0c4db89fd3ca73fc920f84185c27c446b9573 --- /dev/null +++ b/colossalai/tensor/d_tensor/sharding_spec.py @@ -0,0 +1,237 @@ +from copy import deepcopy +from typing import Dict, List + +from ..utils import merge_same_dim_mesh_list +from .misc import ShardingOutOfIndexError + +__all__ = ['DimSpec', 'ShardingException', 'ShardingSpec'] + +ALLGATHER_COST = 20 +SHARD_COST = 5 +STEP_PENALTY = 6 +NAN = 'nan' + + +class DimSpec: + ''' + Sharding spec for single dimension of the sharded tensor describe the sharding dimension of + logical device mesh and give a method to compute the difference between them. + This class is used internally in ShardingSpec. + + Argument: + shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. + Otherwise, the element in shard_list means the data will be sharded in that dimension. + ''' + + def __init__(self, shard_list): + self.is_replica = len(shard_list) == 0 + self.shard_list = shard_list + self.build_difference_2d_dict() + + def __eq__(self, other): + return str(self) == str(other) + + def __repr__(self): + if self.is_replica: + return 'R' + target = 'S' + for dim in self.shard_list: + target += str(dim) + return target + + def _convert_str_to_shard_list(self, str_spec): + ''' + Conver str_spec into shard_list. + + Argument: + str_spec(str): dim spec in str type. + ''' + + if str_spec == 'R': + return [] + if str_spec == 'S0': + return [0] + if str_spec == 'S1': + return [1] + if str_spec == 'S01': + return [0, 1] + + def build_difference_2d_dict(self): + ''' + Build a difference maping for 2D device mesh case. It will be used to + compute the difference between DimSpec pairs. + ''' + + source_spec_list = ['R', 'S0', 'S1', 'S01'] + target_spec_list = ['R', 'S0', 'S1', 'S01'] + difference_dict = {} + for source_spec in source_spec_list: + for target_spec in target_spec_list: + legal_sharding_dims = [] + spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) + source_shard_list = self._convert_str_to_shard_list(source_spec) + target_shard_list = self._convert_str_to_shard_list(target_spec) + + # source same as target + if source_shard_list == target_shard_list: + difference = 0 + + # all_gather(source) -> target + elif len(source_shard_list + ) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list: + difference = ALLGATHER_COST + + # shard(source) -> target + elif len(source_shard_list) == len( + target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[ + -1] not in source_shard_list: + difference = SHARD_COST + + # S1 -> S0 or S0 -> S1 + elif len(source_shard_list) == len(target_shard_list): + # source -> R -> target + difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + + # R -> S01 + elif len(source_shard_list) == len(target_shard_list) - 2: + difference = SHARD_COST + STEP_PENALTY + SHARD_COST + + # S01 -> R + elif len(source_shard_list) == len(target_shard_list) + 2: + difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + + # S1 -> S01 + elif len(source_shard_list) == len(target_shard_list) - 1: + difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST + + # S01 -> S1 + elif len(source_shard_list) == len(target_shard_list) + 1: + difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST + + else: + difference = NAN + difference_dict[spec_pair] = difference + + self.difference_dict = difference_dict + + def dim_diff(self, other): + ''' + The difference between two _DimSpec. + + Argument: + other(_DimSpec): the dim spec to compare with. + + Return: + difference(int): the difference between two _DimSpec. + + Example: + dim_spec = _DimSpec([0]) + other_dim_spec = _DimSpec([0, 1]) + print(dim_spec.difference(other_dim_spec)) + + Output: + 5 + ''' + difference = self.difference_dict[(str(self), str(other))] + return difference + + +class ShardingSpec: + ''' + Sharding spec describes how to shard a tensor with dim_size dimensions. The sharding sequence looks like + [R, R, S0, S1], which means + + Argument: + dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, + and the value of the key describe which logical axis will be sharded in that dimension. + sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. + ''' + + def __init__(self, + dim_size: int, + dim_partition_dict: Dict[int, List[int]] = None, + sharding_sequence: List[DimSpec] = None): + self.dims = dim_size + self.dim_partition_dict = dim_partition_dict + self.sharding_sequence = sharding_sequence + if self.sharding_sequence is None: + assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.' + self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=self.dims, + dim_partition_dict=self.dim_partition_dict) + self.sharding_sequence = self.convert_dict_to_shard_sequence() + + elif self.dim_partition_dict is None: + assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.' + self.dim_partition_dict = self.convert_shard_sequence_to_dict() + + self._sanity_check() + + def _sanity_check(self): + if len(self.sharding_sequence) > self.dims: + raise ShardingOutOfIndexError( + f'sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}.') + + if list(self.dim_partition_dict.keys()) and max(list(self.dim_partition_dict.keys())) >= self.dims: + raise ShardingOutOfIndexError( + f'the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}.' + ) + + def __repr__(self): + res_list = ["ShardingSpec:"] + res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence)) + return ' '.join(res_list) + + def convert_dict_to_shard_sequence(self): + ''' + Convert dim_partition_dict into list of DimSpec, and assign it to sharding_sequence. + ''' + sharding_sequence = [DimSpec([])] * self.dims + for dim, shard_list in self.dim_partition_dict.items(): + sharding_sequence[dim] = DimSpec(shard_list) + return sharding_sequence + + def convert_shard_sequence_to_dict(self): + ''' + Convert sharding_sequence into dim_partition_dict. + ''' + new_dim_partition_dict = {} + for index, dim_spec in enumerate(self.sharding_sequence): + if not dim_spec.is_replica: + if index not in new_dim_partition_dict: + new_dim_partition_dict[index] = [] + new_dim_partition_dict[index].extend(dim_spec.shard_list) + return new_dim_partition_dict + + def spec_diff(self, other): + ''' + This function is a naive version of difference computation. It just simply accumulates difference every dimension between the + pair of sharding sequence. + + Example: + dim_partition_dict = {0: [0, 1]} + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + dim_partition_dict_to_compare = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare) + print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare)) + + Output: + 25 + + Argument: + other(ShardingSpec): The ShardingSpec to compared with. + + Return: + difference(int): Difference between two ShardingSpec. + ''' + assert len(self.sharding_sequence) == len( + other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.' + difference = 0 + for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence): + difference += orig_dim_spec.dim_diff(other_dim_spec) + return difference diff --git a/colossalai/tensor/d_tensor/utils.py b/colossalai/tensor/d_tensor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..644bb6306b42c3ac28717396aeea169370748868 --- /dev/null +++ b/colossalai/tensor/d_tensor/utils.py @@ -0,0 +1,66 @@ +import operator +from functools import reduce +from typing import Dict + +from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec +from colossalai.tensor.d_tensor.layout import Layout + + +def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = False) -> Dict[str, float]: + ''' + This method is used to compute the communication cost for a given layout and comm_spec. + + For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to + compute the communication cost. For shard operation, it is an on-chip operation, so the communication cost is a tiny cost. + + Args: + layout: the layout of the tensor. + comm_spec: the comm_spec to instruct the communication operation. + forward_only: if it is True, we will just count the forward communication cost. + If it is False, we will count both forward and backward communication cost. + ''' + comm_size = reduce(operator.mul, layout.get_sharded_shape_per_device(), 1) + device_mesh = layout.device_mesh + comm_pattern = comm_spec.comm_pattern + logical_process_axis = comm_spec.logical_process_axis + cost_dict = {} + + if comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: + # the comm size for all gather is the size of the gathered tensor + gather_dim = comm_spec.gather_dim + all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1] + all_gather_size = device_mesh.mesh_shape[all_gather_axis] + comm_size_for_all_gather = comm_size * all_gather_size + forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis) + # give a tiny cost to shard + backward_communication_cost = 100 + + if comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: + forward_communication_cost = device_mesh.all_to_all_cost(comm_size, logical_process_axis) + # grad should have same shape as input tensor + # all to all operation has same logical process axis as forward. + backward_communication_cost = device_mesh.all_to_all_cost(comm_size, logical_process_axis) + + if comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: + forward_communication_cost = device_mesh.all_reduce_cost(comm_size, logical_process_axis) + backward_communication_cost = 0 + + if comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: + forward_communication_cost = 0 + backward_communication_cost = device_mesh.all_reduce_cost(comm_size, logical_process_axis) + + if comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: + # give a tiny cost to shard + forward_communication_cost = 100 + backward_communication_cost = device_mesh.all_gather_cost(comm_size, logical_process_axis) + + if forward_only: + cost_dict["forward"] = forward_communication_cost + cost_dict["backward"] = 0 + cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"] + else: + cost_dict["forward"] = forward_communication_cost + cost_dict["backward"] = backward_communication_cost + cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"] + + return cost_dict diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py new file mode 100644 index 0000000000000000000000000000000000000000..8657989235db49623f31191b2b4df84896893db6 --- /dev/null +++ b/colossalai/tensor/dist_spec_mgr.py @@ -0,0 +1,189 @@ +from contextlib import contextmanager + +import torch +import torch.distributed as dist +# from colossalai.nn.layer.utils import divide +from numpy import prod +from packaging import version + +from colossalai.logging import get_dist_logger +from colossalai.tensor.distspec import _DistSpec +from colossalai.tensor.process_group import ProcessGroup + + +# TODO(jiaruifang) circle import, move the divide to colossalai.commons. +# colossalai.tensor shall not import any submodule from colossal.nn +def divide(numerator, denominator): + """Only allow exact division. + + Args: + numerator (int): Numerator of the division. + denominator (int): Denominator of the division. + + Returns: + int: the result of exact division. + """ + assert denominator != 0, 'denominator can not be zero' + assert numerator % denominator == 0, \ + '{} is not divisible by {}'.format(numerator, denominator) + return numerator // denominator + + +class TransformDistSpec(torch.autograd.Function): + + @staticmethod + def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func): + ctx.old_dist_spec = old_dist_spec + ctx.dist_spec = dist_spec + ctx.backward_trans_func = backward_trans_func + ctx.pg = pg + return forward_trans_func(tensor, old_dist_spec, dist_spec, pg) + + @staticmethod + def backward(ctx, grad_outputs): + return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec, + ctx.pg), None, None, None, None, None + + +class DistSpecManager: + + _use_autograd_function: bool = True + + @staticmethod + def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None: + pass + + @staticmethod + def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, + pg: ProcessGroup) -> torch.Tensor: + """_shard_as: shard the tensor w.r.t a distributed specification. + Assuming the tensor passed in is a global (replicated) tensor. + Args: + tensor (torch.Tensor): a global (replicated) tensor before shard + dist_spec (_DistSpec): the distributed spec. to be sharded as. + pg (ProcessGroup): the process group of the corresponding colotensor + Returns: + torch.Tensor: a torch tensor after sharded. + """ + assert old_dist_spec.placement.value == 'r', f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!" + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + + chunk = tensor + idx = pg.tp_local_rank() + num_parts = prod(dist_spec.num_partitions) + for i, dim in enumerate(dist_spec.dims): + num_parts //= dist_spec.num_partitions[i] + + chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i]) + chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size) + idx %= num_parts + return chunk.clone().detach().contiguous() + + @staticmethod + def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + """_gather gather sharded tensors to a replicated one. + Args: + tensor (torch.Tensor): a shared torch tensor + old_dist_spec (_DistSpec): the distributed spec. of the tensor. + + Returns: + torch.Tensor: a replicated tensor. + """ + assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!" + is_cpu_tensor = False + if tensor.device.type == 'cpu': + # pytorch lower than 1.11 dose not support gather a cpu tensor. + # Therefore, we transfer tensor to GPU before gather. + saved_dev = tensor.device + tensor.data = tensor.data.cuda() + is_cpu_tensor = True + + buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())] + assert tensor.device.type == 'cuda' + dist.all_gather(buffer, tensor, group=pg.tp_process_group()) + for i in range(len(old_dist_spec.dims) - 1, -1, -1): + new_buffer = [] + dim = old_dist_spec.dims[i] + num_parts = old_dist_spec.num_partitions[i] + for start in range(0, len(buffer), num_parts): + new_buffer.append(torch.cat(buffer[start:start + num_parts], dim)) + buffer = new_buffer + assert len(buffer) == 1 + + if is_cpu_tensor: + buffer[0].data = buffer[0].data.to(saved_dev) + return buffer[0] + + @staticmethod + def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, + pg: ProcessGroup) -> torch.Tensor: + world_size = pg.tp_world_size() + if world_size == 1: + return tensor + + assert tensor.device.type == "cuda", \ + "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \ + f"collective function, however, we got {tensor.device.type} device" + + gather_dim = old_dist_spec.dims[0] + scatter_dim = dist_spec.dims[0] + shapes = list(tensor.shape) + scattered_dim_size = shapes[scatter_dim] // world_size + gathered_dim_size = shapes[gather_dim] * world_size + shapes[scatter_dim] = scattered_dim_size + + scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)] + gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dist.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) + + output_ = torch.cat(gather_list, dim=gather_dim).contiguous() + assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size + return output_ + + @staticmethod + def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + return tensor + + @staticmethod + def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg) + + @staticmethod + def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + return DistSpecManager._gather(tensor, old_dist_spec, pg) + + @staticmethod + def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + if old_dist_spec == dist_spec: + return tensor + if len(old_dist_spec.dims) == 1 and len(dist_spec.dims) == 1: + # use all-to-all to save memory + return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec, pg) + tensor = DistSpecManager._gather(tensor, old_dist_spec, pg) + return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg) + + @staticmethod + def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, + pg: ProcessGroup) -> torch.Tensor: + assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec" + assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec" + forward_trans_handle = getattr(DistSpecManager, f'_{old_dist_spec.placement.value}2{dist_spec.placement.value}') + if not DistSpecManager._use_autograd_function: + return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg) + backward_trans_handle = getattr(DistSpecManager, + f'_{dist_spec.placement.value}2{old_dist_spec.placement.value}') + return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle, + backward_trans_handle) + + @staticmethod + @contextmanager + def no_grad(): + try: + DistSpecManager._use_autograd_function = False + yield + finally: + DistSpecManager._use_autograd_function = True diff --git a/colossalai/tensor/distspec.py b/colossalai/tensor/distspec.py new file mode 100644 index 0000000000000000000000000000000000000000..3a09f1426e3140f1b2857e0c02409b16e0fba041 --- /dev/null +++ b/colossalai/tensor/distspec.py @@ -0,0 +1,78 @@ +from enum import Enum +from typing import List + +__all__ = ['ReplicaSpec', 'ShardSpec'] + + +class DistPlacementPattern(Enum): + REPLICATE = 'r' + SHARD = 's' + + +class _DistSpec: + """_DistSpec + + A class indicates Distributed Specification. + The DistSpec is only works for the tensor parallel process groups. + Because the dist spec of data parallel process group can be automatically deduced. + This is an internal data structure. + The API for users should be `ShardSpec` and `ReplicaSpec`. + + Args: + dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes. + The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard. + process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None. + """ + + def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info): + + self.placement = dist_placement_pattern + for k, v in meta_info.items(): + setattr(self, k, v) + + def __eq__(self, other: "_DistSpec") -> bool: + if dir(self) != dir(other): + return False + for attr in dir(self): + if not attr.startswith('__') and getattr(self, attr) != getattr(other, attr): + return False + return True + + def __repr__(self) -> str: + attr_list = [] + for attr in dir(self): + if not attr.startswith('__'): + attr_list.append(f'{attr}={str(getattr(self, attr))}') + attr_str = ", ".join(attr_list) + return "DistSpec(" + attr_str + ")" + + +def ReplicaSpec() -> _DistSpec: + """ReplicaSpec + + A distributed specification represents the tensor is replicated among the tensor parallel process group. + + Returns: + _DistSpec: an replicated dist spec instance. + """ + return _DistSpec(DistPlacementPattern.REPLICATE) + + +def ShardSpec(dims: List[int], num_partitions: List[int]) -> _DistSpec: + """ShardSpec + + A distributed specification represents the tensor is sharded among the tensor parallel process group. + + Note: + Currently, only shard on one dimension is valid. In another word, dims should be of size 1. + + Args: + dims (List[int]): a list of dimensions + num_partitions (List[int]): a list of partition number of each dimensions. + + Returns: + _DistSpec: an shard dist spec instance. + """ + assert isinstance(dims, list) and isinstance(num_partitions, list) + assert len(dims) == len(num_partitions) + return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions)) diff --git a/colossalai/tensor/op_wrapper.py b/colossalai/tensor/op_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..1c00066f74655ef997a4e53e4fc1ce97d33f6434 --- /dev/null +++ b/colossalai/tensor/op_wrapper.py @@ -0,0 +1,53 @@ +from typing import ( + Callable, + Dict, +) +import functools + +# Custom sharded ops +_COLOSSAL_OPS: Dict[str, Callable] = {} + + +def _register_colo_op(op, func): + global _COLOSSAL_OPS + _COLOSSAL_OPS[op] = func + + +def colo_op_impl(func): + """ + Provides a way for users to write their own custom operator. This + can be used to override existing ColoTensor operators or write a new + one not supported by ColoTensor. If the operator in question is covered + by ``__torch_function__`` dispatch and has a ColoTensor as any of its + parameters, the function provided will be invoked for that operator. + + Example: + >>> @colo_op_impl(torch.nn.functional.linear) + >>> def my_custom_linear(types, args, kwargs, process_group): + >>> .... + >>> + >>> input = torch.rand(10, 32) + >>> weight = ColoTensor(torch.rand(32, 16)) + >>> bias = ColoTensor(torch.rand(16)) + >>> # This will call `my_custom_linear` instead of the default. + >>> torch.nn.functional.linear(input, weight, bias) + + The types, args and kwargs parameters are the same parameters that are + passed to ``__torch_function__`` dispatch API + (https://pytorch.org/docs/stable/notes/extending.html#extending-torch). + + Args: + func(Callable): Torch function for which we want to provide a sharded + implementation (ex: torch.nn.functional.linear) + """ + + def decorator_sharded_func(wrapped_func): + _register_colo_op(func, wrapped_func) + + @functools.wraps(wrapped_func) + def wrapper(*args, **kwargs): + return wrapped_func(*args, **kwargs) + + return wrapper + + return decorator_sharded_func diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..9c2e0d4adbf1bdc10b25a2602098eef3d36ebbaf --- /dev/null +++ b/colossalai/tensor/param_op_hook.py @@ -0,0 +1,197 @@ +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import Any, List, Tuple + +import torch + +from colossalai.tensor.colo_tensor import ColoTensor +from colossalai.tensor.tensor_spec import ColoTensorSpec + + +class ColoParamOpHook(ABC): + """ + Hook which is triggered by each operation when operands contain ColoParameter. + To customize it, you must inherit this abstract class, and implement ``pre_forward``, + ``post_forward``, ``pre_backward`` and ``post_backward``. + These four methods apply a list of ColoParameter as input args. + """ + + @abstractmethod + def pre_forward(self, params: List[torch.Tensor]) -> None: + pass + + @abstractmethod + def post_forward(self, params: List[torch.Tensor]) -> None: + pass + + @abstractmethod + def pre_backward(self, params: List[torch.Tensor]) -> None: + pass + + @abstractmethod + def post_backward(self, params: List[torch.Tensor]) -> None: + pass + + +class ColoParamOpHookManager: + """ + Manage your param op hooks. It only has static methods. + The only static method you should call is ``use_hooks(*hooks)``. + """ + hooks: Tuple[ColoParamOpHook, ...] = tuple() + + @staticmethod + @contextmanager + def use_hooks(*hooks: ColoParamOpHook): + """Change the param op hooks you use. Nested calling is allowed. + + Example: + >>> with ColoParamOpHookManager.use_hooks(*hooks): + >>> do_something() + >>> with ColoParamOpHookManager.use_hooks(): + >>> // clear hooks + >>> do_something() + """ + try: + old_param_op_hooks = ColoParamOpHookManager.hooks + ColoParamOpHookManager.hooks = hooks + yield + finally: + ColoParamOpHookManager.hooks = old_param_op_hooks + + @staticmethod + def _trigger_pre_forward(params: List[torch.Tensor]) -> None: + for hook in ColoParamOpHookManager.hooks: + hook.pre_forward(params) + + @staticmethod + def _trigger_post_forward(params: List[torch.Tensor]) -> None: + for hook in ColoParamOpHookManager.hooks: + hook.post_forward(params) + + @staticmethod + def _trigger_pre_backward(params: List[torch.Tensor]) -> None: + for hook in ColoParamOpHookManager.hooks: + hook.pre_backward(params) + + @staticmethod + def _trigger_post_backward(params: List[torch.Tensor]) -> None: + for hook in ColoParamOpHookManager.hooks: + hook.post_backward(params) + + @staticmethod + def pre_op(params: List[torch.Tensor], *args: Any) -> list: + ColoParamOpHookManager._trigger_pre_forward(params) + grad_args, rear_args = _get_grad_args(*args) + colo_info = _get_colo_tensors_info(*grad_args) + rets = PreFwdPostBwd.apply(params, *grad_args) + update_args = _update_colo_tensors(colo_info, *rets) + if rear_args is None: + return update_args + else: + arg_zero = (tuple(update_args),) + return arg_zero + rear_args + + @staticmethod + def post_op(params: List[torch.Tensor], arg: Any) -> Any: + ColoParamOpHookManager._trigger_post_forward(params) + colo_info = _get_colo_tensors_info(arg) + ret = PostFwdPreBwd.apply(params, arg) + res = _update_colo_tensors(colo_info, ret) + if len(res) == 1: + return res[0] + else: + return res + + @staticmethod + def has_hook() -> bool: + return len(ColoParamOpHookManager.hooks) > 0 + + +class PreFwdPostBwd(torch.autograd.Function): + + @staticmethod + def forward(ctx, params, *args): + ctx.params = params + return args + + @staticmethod + def backward(ctx, *grads): + ColoParamOpHookManager._trigger_post_backward(ctx.params) + return (None,) + grads + + +class PostFwdPreBwd(torch.autograd.Function): + + @staticmethod + def forward(ctx, params, args): + ctx.params = params + return args + + @staticmethod + def backward(ctx, *grads): + ColoParamOpHookManager._trigger_pre_backward(ctx.params) + return (None,) + grads + + +def _is_grad_tensor(obj) -> bool: + if torch.is_tensor(obj): + if obj.grad_fn is not None or obj.requires_grad: + return True + return False + + +def _has_grad_tensor(obj) -> bool: + if isinstance(obj, tuple) or isinstance(obj, list): + for x in obj: + if _has_grad_tensor(x): + return True + return False + elif isinstance(obj, dict): + for x in obj.values(): + if _has_grad_tensor(x): + return True + return False + else: + return _is_grad_tensor(obj) + + +def _get_grad_args(*args): + # if there is no grad tensors, do nothing + if not _has_grad_tensor(args): + return args, None + # returns the identical args if there is a grad tensor + for obj in args: + if _is_grad_tensor(obj): + return args, None + # otherwise, the first arguement should be a tuple of grad tensors + # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered + arg_zero = args[0] + if not isinstance(arg_zero, tuple): + raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.") + check_grad_flag = False + for obj in arg_zero: + check_grad_flag |= _is_grad_tensor(obj) + if not check_grad_flag: + raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.") + return arg_zero, args[1:] + + +def _get_colo_tensors_info(*args) -> list: + info = [] + for arg in args: + if isinstance(arg, ColoTensor): + info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec))) + else: + info.append(None) + return info + + +def _update_colo_tensors(info, *args) -> list: + ret = [] + for t_info, arg in zip(info, args): + if t_info is not None: + t_cls, spec = t_info + arg = t_cls.from_torch_tensor(arg, spec=spec) + ret.append(arg) + return ret diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py new file mode 100644 index 0000000000000000000000000000000000000000..f108bdc247f5d84e0aa78240b52c5b9b1f17ed06 --- /dev/null +++ b/colossalai/tensor/process_group.py @@ -0,0 +1,319 @@ +from typing import List, Optional + +import torch + +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.logging import get_dist_logger + + +class PyTorchProcessGroupDict(metaclass=SingletonMeta): + + def __init__(self): + # distributed settings + # use this dict to record all Pytorch ProcessGroups + self.dict = {} + # set a distributed logger + self.logger = get_dist_logger('ProcessGroup') + + def log_pg_init(self, rank_list: List[int], backend: str): + str_list = ["Pytorch ProcessGroup Init:"] + str_list.append(f"backend: {backend}") + str_list.append(f"ranks: {rank_list}") + self.logger.info("\n\t".join(str_list), ranks=[0]) + + def get(self, rank_list: List[int], backend: str = 'nccl'): + """Reuse Pytorch ProcessGroup when such a group is initialized + """ + # we need to convert the passed list to a tuple + # since List is unhashable + processgroup_key = (backend, tuple(rank_list)) + if processgroup_key not in self.dict: + self.log_pg_init(rank_list=rank_list, backend=backend) + self.dict[processgroup_key] = torch.distributed.new_group(ranks=rank_list, backend=backend) + return self.dict[processgroup_key] + + +PYTORCHPGDICT_ = PyTorchProcessGroupDict() + + +class ProcessGroup: + """ProcessGroup + Process Group indicates how processes are organized in groups for parallel execution using Tensor Parallelism and Data Parallelism. + + NOTE, the ProcessGroup must be used after `torch.distributed.initialize()` + + + Args: + rank: the global rank of the current process. + ranks: List[int], a list of rank id belongings to this process group. + backend: str, the backend of the process group. + tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. default None means 1. + dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. . default None means len(ranks). + """ + + def __init__(self, + rank: Optional[int] = None, + ranks: Optional[List[int]] = None, + tp_degree: Optional[int] = None, + dp_degree: Optional[int] = None) -> None: + if not torch.distributed.is_initialized(): + self.is_init = False + return + + assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized" + + self._rank = torch.distributed.get_rank() + if rank is not None: + assert self._rank == rank # make sure that the global rank is correct + + if ranks is None: + self._rank_list = list(range(torch.distributed.get_world_size())) + else: + self._rank_list = ranks + self._rank_list.sort() # ensure that the list is in order + + self._world_size = len(self._rank_list) + + if dp_degree is None and tp_degree is None: + self._dp_degree = self._world_size + self._tp_degree = 1 + elif dp_degree and not tp_degree: + self._dp_degree = dp_degree + assert self._world_size % self._dp_degree == 0, f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None" + self._tp_degree = self._world_size // dp_degree + elif not dp_degree and tp_degree: + self._tp_degree = tp_degree + assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None" + self._dp_degree = self._world_size // tp_degree + else: + self._dp_degree = dp_degree + self._tp_degree = tp_degree + assert self._dp_degree * self._tp_degree == self._world_size, \ + f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" \ + f"and TP degree {self._tp_degree}" + + self._tp_rank_list = None + self._dp_rank_list = None + + for i in range(self._dp_degree): + i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)] + PYTORCHPGDICT_.get(i_tp_list, 'nccl') + if self._rank in i_tp_list: + self._tp_rank_list = i_tp_list + + for j in range(self._tp_degree): + j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)] + PYTORCHPGDICT_.get(j_dp_list, 'nccl') + if self._rank in j_dp_list: + self._dp_rank_list = j_dp_list + + self._has_cpu_groups = False + self.is_init = True + + def set_cpu_groups(self): + """set_cpu_groups + Initialize Pytorch process groups for cpu communications. + """ + if self.has_cpu_groups: + return + + for i in range(self._dp_degree): + i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)] + PYTORCHPGDICT_.get(i_tp_list, 'gloo') + + for j in range(self._tp_degree): + j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)] + PYTORCHPGDICT_.get(j_dp_list, 'gloo') + + self._has_cpu_groups = True + + @property + def has_cpu_groups(self) -> bool: + """has_cpu_groups + If cpu groups have been initailized. + + Returns: + bool: cpu process groups have been initialized or not. + """ + return self._has_cpu_groups + + def __repr__(self): + if self.is_init: + ranks_str = f"ProcessGroup(ranks={self._rank_list},\n" + personal_str = f" rank={self._rank}, dp={self._dp_degree}, tp={self._tp_degree})" + return ranks_str + personal_str + else: + return "ProcessGroup not initialized" + + def __eq__(self, obj: 'ProcessGroup') -> bool: + if not isinstance(obj, ProcessGroup): + return False + if self._rank != obj._rank: + return False + if self._rank_list != obj._rank_list: + return False + if self._tp_rank_list != obj._tp_rank_list: + return False + if self._dp_rank_list != obj._dp_rank_list: + return False + if self._tp_degree != obj._tp_degree: + return False + if self._dp_degree != obj._dp_degree: + return False + return True + + def rank(self) -> int: + """rank + + The current rank in the global process group. + + Returns: + int: the rank number + """ + return self._rank + + def ranks_in_group(self) -> List[int]: + """ranks_in_group + + a list of rank number in in the global process group. + + Returns: + List[int]: a list of rank number. + """ + return self._rank_list + + def world_size(self) -> int: + """world_size + + The world size of the global process group. + + Returns: + int: world size + """ + return self._world_size + + def tp_rank_list(self) -> List[int]: + """tp_rank_list + + the rank list in the TP process group containing the current rank. + + Returns: + List[int]: the list of rank number. + """ + return self._tp_rank_list + + def dp_rank_list(self) -> List[int]: + """dp_rank_list + + the rank list in the DP process group containing the current rank. + + Returns: + List[int]: the list of rank number. + """ + return self._dp_rank_list + + def tp_local_rank(self) -> int: + """tp_local_rank + + The local rank number in the current TP process group. + + Returns: + int: tp rank number. + """ + return self._rank % self._tp_degree + + def dp_local_rank(self) -> int: + """dp_local_rank + + The local rank number in the current DP process group. + + Returns: + int: dp rank number. + """ + return self._rank // self._tp_degree + + def dp_world_size(self) -> int: + """dp_world_size + + The world size of the current DP process group. + + Returns: + int: dp world size + """ + return len(self._dp_rank_list) + + def tp_world_size(self) -> int: + """tp_world_size + + The world size of the current TP process group. + + Returns: + int: tp world size + """ + return len(self._tp_rank_list) + + def dp_process_group(self): + """dp_process_group + + the pytorch DP process group containing the current rank. + + Returns: + `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group. + """ + return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl') + + def tp_process_group(self): + """tp_process_group + + the pytorch TP process group containing the current rank. + + Returns: + `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group. + """ + return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl') + + def cpu_dp_process_group(self): + """cpu_dp_process_group + + the pytorch CPU DP process group containing the current rank. + + assert failed if cpu process group is not initialized. + + Returns: + `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group. + """ + assert self._has_cpu_groups + return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo') + + def cpu_tp_process_group(self): + """cpu_tp_process_group + + the pytorch CPU TP process group containing the current rank. + + assert failed if cpu process group is not initialized. + + Returns: + `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group. + """ + assert self._has_cpu_groups + return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo') + + def get_ranks_in_dp(self) -> List[int]: + """get_ranks_in_dp + + ranks in current dp process group. + + Returns: + List[int]: a list of rank number. + """ + return self._dp_rank_list + + def get_ranks_in_tp(self): + """get_ranks_in_tp + + ranks in current tp process group. + + Returns: + List[int]: a list of rank number. + """ + return self._tp_rank_list diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py new file mode 100644 index 0000000000000000000000000000000000000000..0a840006f086353c1ae8cc6db6a695e230e8e59c --- /dev/null +++ b/colossalai/tensor/shape_consistency.py @@ -0,0 +1,744 @@ +import math +from copy import deepcopy +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import numpy as np +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException +from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator + +from .comm_spec import * + +__all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options'] + + +@dataclass +class ShapeConsistencyOptions: + """ + ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency. + """ + # TODO: shape consistency option is not implemented yet + pass + + +def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor: + shape_consistency_manager = ShapeConsistencyManager() + global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {}) + with torch.no_grad(): + global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec, + global_sharding_spec) + return global_tensor + + +def set_shape_consistency_options(options: ShapeConsistencyOptions): + """ + Configure the shape consistency manager via function call. + """ + manager = ShapeConsistencyManager() + manager.options = options + + +class ShapeConsistencyManager(metaclass=SingletonMeta): + + def __init__(self): + self._options = None + self._forward_only = False + self.total_communication_cost = 0 + self.total_transform_steps = 0 + self.cached_spec_pairs_transform_path = {} + + @property + def options(self): + return self._options + + @options.setter + def options(self, options_: ShapeConsistencyOptions): + assert isinstance(options_, ShapeConsistencyOptions) + self._options = options_ + + @property + def forward_only(self): + return self._forward_only + + @forward_only.setter + def forward_only(self, value): + assert isinstance(value, bool) + self._forward_only = value + + def get_all_all_gather_spec(self, source_spec: ShardingSpec, + orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: + ''' + Get all valid sharding specs from source_spec with single all-gather operation, and + accumulate communication cost on origin cost which will finally be used in auto sharding solver. + For the all-gather operation, we just care about the S dimension. + + Argument: + source_spec(ShardingSpec): the ShardingSpec of the source_spec. + orig_cost(Dict[str, float]): the original communication cost before this operation. + + Return: + valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation. + + Example: + dim_partition_dict = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + shape_consistency_manager = ShapeConsistencyManager() + rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) + print(rst_dict) + + Output: + {DistSpec: + shard_sequence: R,S1,R + device_mesh_shape: (4, 4): 0, DistSpec: + shard_sequence: S0,R,R + device_mesh_shape: (4, 4): 0} + ''' + valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD + for target_pair in source_spec.dim_partition_dict.items(): + shard_list = all_gather_simulator(target_pair) + index = target_pair[0] + new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) + + # We won't add empty list into dim_partition_dict + # The key will be popped if the related shard_list is empty + if shard_list: + new_dim_partition_dict[index] = shard_list + else: + new_dim_partition_dict.pop(index) + + # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec + gather_dim = index + logical_process_axis = target_pair[1][-1] + comm_spec = CommSpec( + comm_pattern, + sharding_spec=source_spec, + gather_dim=gather_dim, + # shard_dim will be used during backward + shard_dim=gather_dim, + logical_process_axis=logical_process_axis, + forward_only=self.forward_only) + + # compute the communication cost with CommSpec + cost_dict = comm_spec.get_comm_cost() + + # generate new sharding spec + try: + new_sharding_spec = ShardingSpec(source_spec.device_mesh, + source_spec.entire_shape, + dim_partition_dict=new_dim_partition_dict) + for phase, cost in cost_dict.items(): + cost_dict[phase] = cost + orig_cost_dict[phase] + valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) + except ShardingSpecException: + pass + return valid_spec_dict + + def get_all_all_to_all_spec(self, source_spec: ShardingSpec, + orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: + ''' + Get all valid sharding specs from source_spec with single all-to-all operation, and + accumulate communication cost on origin cost which will finally be used in auto sharding solver. + For the all-to-all operation, we just care about the pairs containing S dimension. + + Argument: + source_spec(ShardingSpec): the ShardingSpec of the source_spec. + orig_cost(Dict[str, float]): the original communication cost before this operation. + + Return: + valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. + + Example: + dim_partition_dict = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + shape_consistency_manager = ShapeConsistencyManager() + rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) + print(rst_dict) + + Output: + {DistSpec: + shard_sequence: S01,R,R + device_mesh_shape: (4, 4): 0, DistSpec: + shard_sequence: R,S1,S0 + device_mesh_shape: (4, 4): 0, DistSpec: + shard_sequence: S0,R,S1 + device_mesh_shape: (4, 4): 0} + ''' + valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD + tensor_dims = len(source_spec.entire_shape) + for f_index in range(tensor_dims - 1): + for b_index in range(f_index + 1, tensor_dims): + # skip (R, R) cases + if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict: + continue + else: + if f_index in source_spec.dim_partition_dict: + # skip (S01, R) -> (R, S01) is NOT allowed + if len(source_spec.dim_partition_dict[f_index]) >= 2: + continue + f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index])) + else: + f_target_pair = (f_index, []) + if b_index in source_spec.dim_partition_dict: + # skip (R, S01) -> (S01, R) is NOT allowed + if len(source_spec.dim_partition_dict[b_index]) >= 2: + continue + b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index])) + else: + b_target_pair = (b_index, []) + + # skip (S1, S0) -> S10 + if f_target_pair[1] and b_target_pair[1] and f_target_pair[1][0] >= b_target_pair[1][0]: + continue + f_shard_list, b_shard_list = all_to_all_simulator(f_target_pair, b_target_pair) + f_index = f_target_pair[0] + b_index = b_target_pair[0] + + # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec + if len(f_shard_list) < len(f_target_pair[1]): + gather_dim = f_index + shard_dim = b_index + logical_process_axis = f_target_pair[1][-1] + else: + gather_dim = b_index + shard_dim = f_index + logical_process_axis = b_target_pair[1][-1] + comm_spec = CommSpec(comm_pattern, + sharding_spec=source_spec, + gather_dim=gather_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + forward_only=self.forward_only) + + # compute the communication cost with CommSpec + cost_dict = comm_spec.get_comm_cost() + new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) + + # We won't add empty list into dim_partition_dict + # The key will be popped if the related shard_list is empty + if f_shard_list: + new_dim_partition_dict[f_index] = f_shard_list + else: + new_dim_partition_dict.pop(f_index) + if b_shard_list: + new_dim_partition_dict[b_index] = b_shard_list + else: + new_dim_partition_dict.pop(b_index) + + # generate new sharding spec + try: + new_sharding_spec = ShardingSpec(source_spec.device_mesh, + source_spec.entire_shape, + dim_partition_dict=new_dim_partition_dict) + for phase, cost in cost_dict.items(): + cost_dict[phase] = cost + orig_cost_dict[phase] + valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) + except ShardingSpecException: + pass + + return valid_spec_dict + + def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): + ''' + Get all valid sharding specs from source_spec with single shard operation, and + accumulate commucation cost on origin cost which will finally be used in auto sharding solver. + For the sharding operation, we just care about legal sharding dimensions. + + Argument: + source_spec(ShardingSpec): the ShardingSpec of the source_spec. + orig_cost(float): the original communication cost before this operation. + + Return: + valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. + + Example: + dim_partition_dict = {0: [0]} + # DistSpec: + # shard_sequence: S0,R,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + shape_consistency_manager = ShapeConsistencyManager() + rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) + print(rst_dict) + + Output: + {DistSpec: + shard_sequence: S01,R,R + device_mesh_shape: (4, 4): 0, DistSpec: + shard_sequence: S0,S1,R + device_mesh_shape: (4, 4): 0, DistSpec: + shard_sequence: S0,R,S1 + device_mesh_shape: (4, 4): 0} + ''' + valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD + + # legal sharding dims means the mesh_id is still available to use. + legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))] + for dim, shard_list in source_spec.dim_partition_dict.items(): + for element in shard_list: + legal_sharding_dims.remove(element) + if len(legal_sharding_dims) == 0: + return valid_spec_dict + + tensor_dims = len(source_spec.entire_shape) + + for index in range(tensor_dims): + if index not in source_spec.dim_partition_dict: + shard_list_list = shard_simulator((index, []), legal_sharding_dims) + else: + shard_list_list = shard_simulator((index, source_spec.dim_partition_dict[index]), legal_sharding_dims) + if not shard_list_list: + continue + for shard_list in shard_list_list: + new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) + new_dim_partition_dict[index] = shard_list + + # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec + shard_dim = index + logical_process_axis = shard_list[-1] + comm_spec = CommSpec(comm_pattern, + sharding_spec=source_spec, + gather_dim=shard_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + forward_only=self.forward_only) + + # compute the communication cost with CommSpec + cost_dict = comm_spec.get_comm_cost() + + # generate new sharding spec + try: + new_sharding_spec = ShardingSpec(source_spec.device_mesh, + source_spec.entire_shape, + dim_partition_dict=new_dim_partition_dict) + for phase, cost in cost_dict.items(): + cost_dict[phase] = cost + orig_cost_dict[phase] + valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) + except ShardingSpecException: + pass + return valid_spec_dict + + def get_all_mix_gather_spec(self, source_spec: ShardingSpec, + orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: + ''' + S0S1 -> RR + S1S0 -> RR + S01R -> RR + RS01 -> RR + ''' + valid_spec_dict = {} + comm_pathern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD + tensor_dims = len(source_spec.entire_shape) + for f_index in range(tensor_dims - 1): + for b_index in range(f_index + 1, tensor_dims): + if (f_index not in source_spec.dim_partition_dict) and (b_index not in source_spec.dim_partition_dict): + continue + else: + if f_index in source_spec.dim_partition_dict: + # skip (S10, R) -> (R, R) + if len(f_target_pair[1]) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]: + continue + f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index])) + else: + f_target_pair = (f_index, []) + if b_index in source_spec.dim_partition_dict: + # skip (R, S10) -> (R, R) + if len(b_target_pair[1]) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]: + continue + b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index])) + else: + b_target_pair = (b_index, []) + + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + comm_spec = CommSpec(comm_pathern, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=self.forward_only, + mix_gather=True) + cost_dict = comm_spec.get_comm_cost() + new_dim_partition_dict = {} + # generate new sharding spec + try: + new_sharding_spec = ShardingSpec(source_spec.device_mesh, + source_spec.entire_shape, + dim_partition_dict=new_dim_partition_dict) + for phase, cost in cost_dict.items(): + cost_dict[phase] = cost + orig_cost_dict[phase] + valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) + except ShardingSpecException: + pass + + return valid_spec_dict + + def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]: + ''' + Get all valid sharding specs from source_spec with one step transform, and + accumulate commucation cost on origin cost which will finally be used in auto sharding solver. + Note: + all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before, + and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive, + we could safely put them together. + + Argument: + source_spec(ShardingSpec): the ShardingSpec of the source_spec. + orig_cost(float): the original communication cost before this operation. + + Return: + valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. + ''' + valid_spec_dict = {} + valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost_dict)) + valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost_dict)) + valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict)) + return valid_spec_dict + + def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem: + """memory cost of the communication action sequence + + Args: + comm_action_sequence (List[CommSpec]): list of communication actions + + Returns: + TrainCycleItem: memory (numel) cost of such comm_action_sequence + """ + + def compute_shape(sharding_spec: ShardingSpec): + shape = sharding_spec.entire_shape + new_shape = [] + for dim, shard in sharding_spec.dim_partition_dict.items(): + new_shape.append(shape[dim] // len(shard)) + return new_shape + + def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): + """analyze all_gather memory footprint + all_gather will allocate memory for the output tensor, and there will be temp memory for + all_gather operation, which is twice the size of output tensor + + Args: + comm_spec (CommSpec): input CommSpec + discard_input (bool): whether to discard the input tensor + alloc_numel (int): current allocated numel + peak_numel (int): current peak numel + """ + input_shape = compute_shape(comm_spec.sharding_spec) + input_numel = np.prod(input_shape) + output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] + peak_numel = max(peak_numel, alloc_numel + output_numel * 2) + alloc_numel += output_numel + if discard_input: + alloc_numel -= input_numel + + return alloc_numel, peak_numel + + def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): + """analyze split memory footprint + split will allocate memory for the output tensor if we don't apply shard on the first dimension of + the input tensor. If we apply shard on the first dimension, the `torch.tensor.contiguous()` will not + generate new tensor in this case, so no memory will be allocated. + + Args: + comm_spec (CommSpec): input CommSpec + discard_input (bool): whether to discard the input tensor + alloc_numel (int): current allocated numel + peak_numel (int): current peak numel + """ + shard_dim = comm_spec.shard_dim + if shard_dim != 0: + # if we don't shard the tensor on the first dimension, the split action will + # generate a new tensor + input_shape = compute_shape(comm_spec.sharding_spec) + input_numel = np.prod(input_shape) + output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] + alloc_numel += output_numel + peak_numel = max(peak_numel, alloc_numel) + if discard_input: + alloc_numel -= input_numel + else: + # if we shard the tensor on the first dimension, the split action will not generate + # a new tensor, and as it will preserve a reference to the input tensor, we could + # override the discard_input option here + # NOTE: this special case might fail in some weird cases, e.g. if we have three split + # actions in the comm actions sequence, the first split action operate on the second dimension, + # the second split action operate on the first dimension, and the third split action operate, again, + # on the second dimension. Therefore, after the first two actions in the sequence, we will allocate + # memory the same size as the output of first split action. However, the third split action will discard + # the input tensor, and it actually should discard the tensor generated by the first split action, so in + # the current memory estimation framework, we will overestimate the memory usage. But the above case is + # kind of weird, and I think we could ignore it for now. + pass + + return alloc_numel, peak_numel + + def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): + """ + a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory + """ + return alloc_numel, peak_numel + + def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): + """analyze all_to_all memory footprint + all_to_all will allocate memory for the output tensor, and temp memory of all_to_all action + is twice the size of output tensor if we shard input tensor on the first dimension, otherwise + the temp memory is three times the size of output tensor + + Args: + comm_spec (CommSpec): input CommSpec + discard_input (bool): whether to discard the input tensor + alloc_numel (int): current allocated numel + peak_numel (int): current peak numel + """ + input_shape = compute_shape(comm_spec.sharding_spec) + input_numel = np.prod(input_shape) + output_numel = input_numel + shard_dim = comm_spec.shard_dim + if shard_dim != 0: + peak_numel = max(peak_numel, alloc_numel + output_numel * 3) + else: + peak_numel = max(peak_numel, alloc_numel + output_numel * 2) + alloc_numel += output_numel + if discard_input: + alloc_numel -= input_numel + + return alloc_numel, peak_numel + + def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): + """ + a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory + """ + return alloc_numel, peak_numel + + pattern_to_func_dict = { + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis], + CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: [all2all_analysis, all2all_analysis], + CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: [split_analysis, gather_analysis], + CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: [reduce_analysis, identity_analysis], + CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: [identity_analysis, reduce_analysis], + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: [], + } + + fwd_actions = [] + bwd_actions = [] + + # construct forward and backward comm actions sequence + for comm_spec in comm_action_sequence: + comm_spec: CommSpec + fwd_action, bwd_action = pattern_to_func_dict[comm_spec.comm_pattern] + fwd_actions.append(fwd_action) + bwd_actions.append(bwd_action) + + # analyze memory footprint of forward comm actions sequence + fwd_alloc_numel = 0 + fwd_peak_numel = 0 + for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)): + # the first forward comm action will not discard input + fwd_action, comm_spec = action_spec_pair + fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel, + fwd_peak_numel) if idx == 0 else fwd_action( + comm_spec, True, fwd_alloc_numel, fwd_peak_numel) + + # analyze memory footprint for backward comm actions sequence + bwd_alloc_numel = 0 + bwd_peak_numel = 0 + for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))): + bwd_action, comm_spec = action_spec_pair + bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel, + bwd_peak_numel) if idx == 0 else bwd_action( + comm_spec, True, bwd_alloc_numel, bwd_peak_numel) + + fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel) + bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel) + total_mem = MemoryCost(activation=fwd_alloc_numel + bwd_alloc_numel) + + return TrainCycleItem(fwd_mem, bwd_mem, total_mem) + + def shape_consistency(self, source_spec: ShardingSpec, + target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]: + ''' + This method will find a path to transform source_spec to target_spec with + a greedy algorithm. + The basic idea is: + Step1: + Generate all one-step transform sequences from source_spec. + Step2: + Pick the 'best' sharding spec following the heuristic function. + Step3: + Repeat above steps until the source spec transform to target spec. + + During finding the transform path, commucation cost will be accumulated, and it + will be finally used in auto parallel solver. + + Additionally, to avoid repeating the path search in runtime, we cached all solved path + in auto parallel strategy building time, which could handle most of cases in runtime. + + Argument: + source_spec(ShardingSpec): ShardingSpec of the source activation. + target_spec(ShardingSpec): ShardingSpec of the target activation. + + Return: + transform_path(List[ShardingSpec]): The transform path from source_spec to target_spec, + it contains the source_spec and target_spec. + comm_action_sequence(List[CommSpec]): Keep the communication operations to complete the shape consistency in order. + total_cost(float): total cost to complete shape consistency transform. + + Example: + dim_partition_source = {1: [0, 1]} + dim_partition_target = {0: [0, 1]} + # DistSpec: + # shard_sequence: R,S01,R + # device_mesh_shape: (4, 4) + sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source) + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) + transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(sharding_spec_source, sharding_spec_target) + print(f'transform_path: {transform_path}') + print(f'comm_action_sequence: {comm_action_sequence}') + print(f'total_cost: {total_cost}') + + output: + transform_path: [DistSpec: + shard_sequence: R,S01,R + device_mesh_shape: (4, 4), DistSpec: + shard_sequence: R,S0,R + device_mesh_shape: (4, 4), DistSpec: + shard_sequence: S0,R,R + device_mesh_shape: (4, 4), DistSpec: + shard_sequence: S01,R,R + device_mesh_shape: (4, 4)] + comm_action_sequence: [CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), + CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0), + CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)] + total_cost: 12294.402000000002 + ''' + MAX_TRANSFORM_STEPS = 20 + total_cost_dict = {'forward': 0, 'backward': 0, 'total': 0} + total_steps = 0 + transform_path = [] + comm_action_sequence = [] + spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence)) + self.cached_spec_pairs_transform_path[spec_pairs] = (None, None) + + # We do nothing if the sharding spec is all the same. + if source_spec.sharding_sequence_difference(target_spec) == 0: + self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence) + return (transform_path, comm_action_sequence, total_cost_dict) + + temp_sharding_spec = source_spec + + transform_path.append(temp_sharding_spec) + # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms + while total_steps <= MAX_TRANSFORM_STEPS: + valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_spec, total_cost_dict) + best_difference_score = math.inf + + for sharding_spec, info_pairs in valid_transform_spec_dict.items(): + comm_spec, cost_dict = info_pairs + spec_difference = sharding_spec.sharding_sequence_difference(target_spec) + + if spec_difference == 0: + for phase, cost in total_cost_dict.items(): + total_cost_dict[phase] = cost + cost_dict[phase] + transform_path.append(sharding_spec) + comm_action_sequence.append(comm_spec) + self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence) + return (transform_path, comm_action_sequence, total_cost_dict) + + if spec_difference < best_difference_score: + temp_sharding_spec = sharding_spec + temp_cost_dict = cost_dict + temp_comm_spec = comm_spec + best_difference_score = spec_difference + + transform_path.append(temp_sharding_spec) + comm_action_sequence.append(temp_comm_spec) + for phase, cost in total_cost_dict.items(): + total_cost_dict[phase] = cost + temp_cost_dict[phase] + total_steps += 1 + + raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") + + def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor: + ''' + Apply target_spec to tensor with source sharding spec, the transform path is generated by the + shape_consistency method. + + Argument: + tensor_with_sharding_spec (torch.Tensor): a tensor with source sharding spec to be transformed to the target spec. + target_spec (ShardingSpec): The tensor transform processes will be directed by the target_spec. + + Example: + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = torch.Size((4, 2)) + shape_consistency_manager = ShapeConsistencyManager() + dim_partition_source = {0: [0]} + dim_partition_target = {1: [0]} + + # DistSpec: + # shard_sequence: S0,R + # device_mesh_shape: (2, 2) + sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source) + + # DistSpec: + # shard_sequence: R,S0 + # device_mesh_shape: (2, 2) + sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) + + if rank in (0, 1): + sharded_tensor_0 = torch.zeros(2, 1) + sharded_tensor_1 = torch.ones(2, 1) + # tensor([[0., 1.], + # [0., 1.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + if rank in (2, 3): + sharded_tensor_0 = torch.ones(2, 1) * 2 + sharded_tensor_1 = torch.ones(2, 1) * 3 + # tensor([[2., 3.], + # [2., 3.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + + tensor_to_comm.sharding_spec = sharding_spec_source + shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target) + print(tensor_to_comm) + + Output in rank0 and rank2: + tensor([[0.], + [0.], + [2.], + [2.]]) + + Output in rank1 and rank3: + tensor([[1.], + [1.], + [3.], + [3.]]) + ''' + _, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec) + for comm_spec in comm_action_sequence: + tensor_with_sharding_spec = comm_spec.covert_spec_to_action(tensor_with_sharding_spec) + tensor_with_sharding_spec.sharding_spec = target_spec + return tensor_with_sharding_spec + + def apply_for_autoparallel_runtime(self, tensor, source_spec, target_spec): + _, comm_action_sequence, _ = self.shape_consistency(source_spec, target_spec) + for comm_spec in comm_action_sequence: + tensor = comm_spec.covert_spec_to_action(tensor) + tensor.sharding_spec = target_spec + return tensor diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..bed320130ccdc09cbd2e82f2c1d1ba6d01f4e295 --- /dev/null +++ b/colossalai/tensor/sharding_spec.py @@ -0,0 +1,296 @@ +import operator +from copy import deepcopy +from functools import reduce + +import torch + +from colossalai.device.device_mesh import DeviceMesh + +from .utils import merge_same_dim_mesh_list + +__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec'] + +ALLGATHER_COST = 20 +SHARD_COST = 5 +STEP_PENALTY = 6 +NAN = 'nan' + + +class _DimSpec: + ''' + Sharding spec for single dimension of the sharded tensor describe the sharding dimension of + logical device mesh and give a method to compute the difference between them. + This class is used internally in ShardingSpec. + + Argument: + shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. + Otherwise, the element in shard_list means the data will be sharded in that dimension. + ''' + + def __init__(self, shard_list): + self.is_replica = len(shard_list) == 0 + self.shard_list = shard_list + self.build_difference_2d_dict() + + def __eq__(self, other): + return str(self) == str(other) + + def __repr__(self): + if self.is_replica: + return 'R' + target = 'S' + for dim in self.shard_list: + target += str(dim) + return target + + def _convert_str_to_shard_list(self, str_spec): + ''' + Conver str_spec into shard_list. + + Argument: + str_spec(str): dim spec in str type. + ''' + + if str_spec == 'R': + return [] + if str_spec == 'S0': + return [0] + if str_spec == 'S1': + return [1] + if str_spec == 'S01': + return [0, 1] + + def build_difference_2d_dict(self): + ''' + Build a difference maping for 2D device mesh case. It will be used to + compute the difference between DimSpec pairs. + ''' + + source_spec_list = ['R', 'S0', 'S1', 'S01'] + target_spec_list = ['R', 'S0', 'S1', 'S01'] + difference_dict = {} + for source_spec in source_spec_list: + for target_spec in target_spec_list: + legal_sharding_dims = [] + spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) + source_shard_list = self._convert_str_to_shard_list(source_spec) + target_shard_list = self._convert_str_to_shard_list(target_spec) + + # source same as target + if source_shard_list == target_shard_list: + difference = 0 + + # all_gather(source) -> target + elif len(source_shard_list + ) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list: + difference = ALLGATHER_COST + + # shard(source) -> target + elif len(source_shard_list) == len( + target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[ + -1] not in source_shard_list: + difference = SHARD_COST + + # S1 -> S0 or S0 -> S1 + elif len(source_shard_list) == len(target_shard_list): + # source -> R -> target + difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + + # R -> S01 + elif len(source_shard_list) == len(target_shard_list) - 2: + difference = SHARD_COST + STEP_PENALTY + SHARD_COST + + # S01 -> R + elif len(source_shard_list) == len(target_shard_list) + 2: + difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + + # S1 -> S01 + elif len(source_shard_list) == len(target_shard_list) - 1: + difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST + + # S01 -> S1 + elif len(source_shard_list) == len(target_shard_list) + 1: + difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST + + else: + difference = NAN + difference_dict[spec_pair] = difference + + self.difference_dict = difference_dict + + def difference(self, other): + ''' + The difference between two _DimSpec. + + Argument: + other(_DimSpec): the dim spec to compare with. + + Return: + difference(int): the difference between two _DimSpec. + + Example: + dim_spec = _DimSpec([0]) + other_dim_spec = _DimSpec([0, 1]) + print(dim_spec.difference(other_dim_spec)) + + Output: + 5 + ''' + difference = self.difference_dict[(str(self), str(other))] + return difference + + +class ShardingSpecException(Exception): + pass + + +class ShardingOutOfIndexError(ShardingSpecException): + pass + + +class DuplicatedShardingDimensionError(ShardingSpecException): + pass + + +class ShardingNotDivisibleError(ShardingSpecException): + pass + + +class ShardingSpec: + ''' + Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong + to, the entire shape of the tensor before sharded, and the sharding sequence looks like + [R, R, S0, S1]. + + Argument: + device_mesh(DeviceMesh): A logical view of a physical mesh. + entire_shape(torch.Size): The entire shape of tensor before sharded. + dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, + and the value of the key decribe which logical axis will be sharded in that dimension. + sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. + ''' + + def __init__(self, + device_mesh: DeviceMesh, + entire_shape: torch.Size, + dim_partition_dict=None, + sharding_sequence=None): + self.device_mesh = device_mesh + + if isinstance(entire_shape, (list, tuple)): + entire_shape = torch.Size(entire_shape) + self.entire_shape = entire_shape + self.dim_partition_dict = dim_partition_dict + self.sharding_sequence = sharding_sequence + if self.sharding_sequence is None: + assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.' + self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape), + dim_partition_dict=self.dim_partition_dict) + self.convert_dict_to_shard_sequence() + elif self.dim_partition_dict is None: + assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.' + self.convert_shard_sequence_to_dict() + self._sanity_check() + + def __repr__(self): + res_list = ["DistSpec:"] + res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence)) + res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.mesh_shape}") + return ' '.join(res_list) + + def _sanity_check(self): + # make sure all axes in logical device mesh only be used once + dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim())) + for dim, shard_list in self.dim_partition_dict.items(): + for element in shard_list: + if element in dim_check_list: + dim_check_list.remove(element) + else: + raise DuplicatedShardingDimensionError( + f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") + + # make sure that the dimension is not out of index + for dim in self.dim_partition_dict.keys(): + if dim >= len(self.entire_shape): + raise ShardingOutOfIndexError( + f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions" + ) + + # make sure that the sharding for a dimension is divisible by the number of devices + for dim, shard_list in self.dim_partition_dict.items(): + tensor_dim_size = self.entire_shape[dim] + num_devices = 1 + + for element in shard_list: + num_devices *= self.device_mesh.mesh_shape[element] + + if tensor_dim_size % num_devices != 0: + raise ShardingNotDivisibleError( + f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.' + ) + + def convert_dict_to_shard_sequence(self): + ''' + Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence. + ''' + sharding_sequence = [_DimSpec([])] * len(self.entire_shape) + for dim, shard_list in self.dim_partition_dict.items(): + sharding_sequence[dim] = _DimSpec(shard_list) + self.sharding_sequence = sharding_sequence + + def convert_shard_sequence_to_dict(self): + ''' + Convert sharding_sequence into dim_partition_dict. + ''' + new_dim_partition_dict = {} + for index, dim_spec in enumerate(self.sharding_sequence): + if not dim_spec.is_replica: + if index not in new_dim_partition_dict: + new_dim_partition_dict[index] = [] + new_dim_partition_dict[index].extend(dim_spec.shard_list) + self.dim_partition_dict = new_dim_partition_dict + + def sharding_sequence_difference(self, other): + ''' + This function is a naive version of difference computation. It just simply accumulates difference every dimension between the + pair of sharding sequence. + + Example: + dim_partition_dict = {0: [0, 1]} + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + dim_partition_dict_to_compare = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare) + print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare)) + + Output: + 25 + + Argument: + other(ShardingSpec): The ShardingSpec to compared with. + + Return: + difference(int): Difference between two ShardingSpec. + ''' + assert len(self.sharding_sequence) == len( + other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.' + difference = 0 + for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence): + difference += orig_dim_spec.difference(other_dim_spec) + return difference + + def get_sharded_shape_per_device(self): + + sharded_shape = list(self.entire_shape) + for dim, shard_list in self.dim_partition_dict.items(): + mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] + shard_partitions = reduce(operator.mul, mesh_list, 1) + assert sharded_shape[ + dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' + sharded_shape[dim] //= shard_partitions + return torch.Size(sharded_shape) diff --git a/colossalai/tensor/tensor_spec.py b/colossalai/tensor/tensor_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..580df9f8f31023f0623cadc08fe67d49b309819f --- /dev/null +++ b/colossalai/tensor/tensor_spec.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from typing import Optional + +from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec +from colossalai.tensor.process_group import ProcessGroup + +from .compute_spec import ComputeSpec + + +@dataclass +class ColoTensorSpec: + """ ColoTensorSpec + + A data class for specifications of the `ColoTensor`. + It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`. + The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`. + """ + pg: ProcessGroup + dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE) + compute_attr: Optional[ComputeSpec] = None diff --git a/colossalai/tensor/utils.py b/colossalai/tensor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6e30f97fef0388ac4d65d7e83aa458257026f3ee --- /dev/null +++ b/colossalai/tensor/utils.py @@ -0,0 +1,226 @@ +from typing import Dict, Iterator, List, Tuple, Union + +import torch +import torch.nn as nn + +from colossalai.tensor.colo_tensor import ColoTensor + + +def all_gather_simulator(target_pair): + ''' + Simulating all-gather operation, analyze the communication cost + and simulate the influence of the DimSpec. + + We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed. + Therefore, all gather operation just remove the last element in shard list, + e.g.: + all-gather(S01) -> S0 + + Argument: + target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, + and the second element describes which logical axis will be sharded in that dimension. + ''' + _, shard_list = target_pair + new_shard_list = shard_list[:-1] + + return new_shard_list + + +def all_to_all_simulator(f_target_pair, b_target_pair): + ''' + Simulating all-to-all operation, analyze the communication cost + and simulate the influence of the DimSpec. + + We BANNED all representations which shard_list in decreasing order, + such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed. + Therefore, if the behind shard_list is not None, we just extend it to the front shard_list. + Argument: + target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, + and the second element describes which logical axis will be sharded in that dimension. + e.g.: + all-to-all(S0, S1) -> [S01, R] + all-to-all(S0, R) -> [R, S0] + Otherwise, we extend the front shard_list to behind. + e.g.: + all-to-all(R, S1) -> [S1, R] + + Argument: + target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, + and the second element describes which logical axis will be sharded in that dimension. + ''' + _, f_shard_list = f_target_pair + _, b_shard_list = b_target_pair + if not len(b_shard_list): + b_shard_list.extend(f_shard_list) + f_shard_list = [] + else: + f_shard_list.extend(b_shard_list) + b_shard_list = [] + + return f_shard_list, b_shard_list + + +def shard_simulator(target_pair, legal_sharding_dims): + ''' + Simulating shard operation, analyze the communication cost(always ZERO) + and simulate the influence of the DimSpec. + + We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed. + In addition, We BANNED all representations which shard_list in decreasing order, + such as S10, so shard(S0) -> S10 is NOT allowed. + Therefore, for the R dimension, we could just append any legal sharding dim on it. + e.g.: + shard(R) -> S0 + For the S dimension, we need to make sure the shard_list after sharding still keep rising order. + e.g: + shard(S0) -> S01 + + Argument: + target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, + and the second element decribes which logical axis will be sharded in that dimension. + ''' + _, shard_list = target_pair + shard_list_list = [] + for dim in legal_sharding_dims: + if len(shard_list) != 0 and dim <= shard_list[-1]: + continue + new_shard_list = shard_list + [dim] + shard_list_list.append(new_shard_list) + + return shard_list_list + + +def mix_gather_simulator(f_target_pair, b_target_pair): + ''' + Assume index of f and b target pairs are 'f' and 'b' + S0S1 => Input: (f, [0]), (b, [1]) Output: [b, f], (1, 0) + S1S0 => Input: (f, [1]), (b, [0]) Output: [b, f], (0, 1) + S01R => Input: (f, [0, 1]), (b, []) Output: [f], (1, 1) + RS01 => Input: (f, []), (b, [0, 1]) Output: [b], (1, 1) + S10R => Input: (f, [0, 1]), (b, []) Output: [f], (0, 0) + RS10 => Input: (f, []), (b, [0, 1]) Output: [b], (0, 0) + ''' + if f_target_pair[1] and b_target_pair[1]: + leading_dim = b_target_pair[1] > f_target_pair[1] + return [b_target_pair[0], f_target_pair[0]], [int(leading_dim), int(leading_dim ^ 1)] + if f_target_pair[1]: + leading_dim = f_target_pair[1][0] < f_target_pair[1][1] + return [ + f_target_pair[0], + ], [int(leading_dim), int(leading_dim)] + if b_target_pair[1]: + leading_dim = b_target_pair[1][0] < b_target_pair[1][1] + return [ + b_target_pair[0], + ], [int(leading_dim), int(leading_dim)] + + +# The function is credited to PyTorch Team +def named_params_with_colotensor( + module: nn.Module, + prefix: str = '', + recurse: bool = True, +) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]: + r"""Returns an iterator over module parameters (together with the + ColoTensor parameters), yielding both the name of the parameter + as well as the parameter itself. This is typically passed to a + :class:torchshard._shard.sharded_optim.ShardedOptimizer + + Args: + prefix (str): prefix to prepend to all parameter names. + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only parameters that + are direct members of this module. + + Yields: + (string, Union[Tensor, ColoTensor]): Tuple containing + the name and parameter (or ColoTensor parameter) + + Example: + + >>> model = torch.nn.Linear(*linear_size) + >>> delattr(model.weight) + >>> setattr(model.weight, ColoTensor(...)) + >>> for name, param in named_params_with_colotensor(model): + >>> if name in ['weight']: + >>> print(param.size()) + + """ + modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)] + + memo = set() + for mod_prefix, mod in modules: + # find all sharded tensor params + for name, val in vars(mod).items(): + if isinstance(val, ColoTensor) and val not in memo: + memo.add(val) + name = mod_prefix + ('.' if mod_prefix else '') + name + yield name, val + + # find all nn.Parameters + for name, val in module.named_parameters(): + yield name, val + + +def _convert_tensor(tensor: torch.Tensor) -> ColoTensor: + return ColoTensor(tensor) + + +def convert_parameter(module: torch.nn.Module, param_name: str): + # Perform some validation first. + if not hasattr(module, param_name): + raise ValueError(f'module: {module} does not have parameter with name: {param_name}') + + tensor = getattr(module, param_name) + if not isinstance(tensor, torch.Tensor): + raise ValueError( + f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}') + + if not tensor.is_contiguous(): + raise ValueError(f'param: {param_name} is not a contiguous Tensor') + + st = _convert_tensor(tensor) + + # Replace param with ColoTensor. + + # Need to delete the attribute first since param_name might be + # torch.nn.Parameter and can't be replaced with ColoTensor which is + # not torch.nn.Parameter. + delattr(module, param_name) + + # Now we can set the attribute appropriately. + setattr(module, param_name, st) + + +def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]: + ''' + This method is used to convert the negative dim value to positive. + ''' + dims_to_convert = [] + for dim, mesh_list in dim_partition_dict.items(): + if dim < 0: + dims_to_convert.append(dim) + for dim in dims_to_convert: + dim_partition_dict.pop(dim) + dim_partition_dict[dim_size + dim] = mesh_list + return dim_partition_dict + + +def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]: + ''' + This method is used to merge the different key value which points to same physical position. + + For example: + dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position. + In this method, above dim_partition_dict will be converted to {1: [0, 1]} + ''' + converted_dim_partition_dict = {} + for dim, mesh_list in dim_partition_dict.items(): + if dim < 0: + dim = dim_size + dim + if dim not in converted_dim_partition_dict: + converted_dim_partition_dict[dim] = mesh_list + else: + converted_dim_partition_dict[dim].extend(mesh_list) + + return converted_dim_partition_dict diff --git a/colossalai/testing/__init__.py b/colossalai/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c53e0f44c7e0a4d9c925844a84c5e4e775fd6883 --- /dev/null +++ b/colossalai/testing/__init__.py @@ -0,0 +1,17 @@ +from .comparison import assert_close, assert_close_loose, assert_equal, assert_equal_in_group, assert_not_equal +from .pytest_wrapper import run_on_environment_flag +from .utils import ( + clear_cache_before_run, + free_port, + parameterize, + rerun_if_address_is_in_use, + rerun_on_exception, + skip_if_not_enough_gpus, + spawn, +) + +__all__ = [ + 'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize', + 'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn', + 'clear_cache_before_run', 'run_on_environment_flag' +] diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d0da168c72c286f7515056bd8acbccb22695d --- /dev/null +++ b/colossalai/testing/comparison.py @@ -0,0 +1,30 @@ +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.testing import assert_close + + +def assert_equal(a: Tensor, b: Tensor): + assert torch.all(a == b), f'expected a and b to be equal but they are not, {a} vs {b}' + + +def assert_not_equal(a: Tensor, b: Tensor): + assert not torch.all(a == b), f'expected a and b to be not equal but they are, {a} vs {b}' + + +def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3): + assert_close(a, b, rtol=rtol, atol=atol) + + +def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): + # all gather tensors from different ranks + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_list, tensor, group=process_group) + + # check if they are equal one by one + for i in range(world_size - 1): + a = tensor_list[i] + b = tensor_list[i + 1] + assert torch.all(a == b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}' diff --git a/colossalai/testing/pytest_wrapper.py b/colossalai/testing/pytest_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..a472eb3723ec1e8ce8705aa978355a460179a328 --- /dev/null +++ b/colossalai/testing/pytest_wrapper.py @@ -0,0 +1,40 @@ +""" +This file will not be automatically imported by `colossalai.testing` +as this file has a dependency on `pytest`. Therefore, you need to +explicitly import this file `from colossalai.testing.pytest_wrapper import `.from +""" + +import pytest +import os + + +def run_on_environment_flag(name: str): + """ + Conditionally run a test based on the environment variable. If this environment variable is set + to 1, this test will be executed. Otherwise, this test is skipped. The environment variable is default to 0. + + Args: + name (str): the name of the environment variable flag. + + Usage: + # in your pytest file + @run_on_environment_flag(name='SOME_FLAG') + def test_for_something(): + do_something() + + # in your terminal + # this will execute your test + SOME_FLAG=1 pytest test_for_something.py + + # this will skip your test + pytest test_for_something.py + + """ + assert isinstance(name, str) + flag = os.environ.get(name.upper(), '0') + + reason = f'Environment varialbe {name} is {flag}' + if flag == '1': + return pytest.mark.skipif(False, reason=reason) + else: + return pytest.mark.skipif(True, reason=reason) diff --git a/colossalai/testing/random.py b/colossalai/testing/random.py new file mode 100644 index 0000000000000000000000000000000000000000..ad6d24a4b94b152f20f404f5b1bb1cae2d74a0b5 --- /dev/null +++ b/colossalai/testing/random.py @@ -0,0 +1,19 @@ +import random + +import numpy as np +import torch + + +def seed_all(seed, cuda_deterministic=False): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if cuda_deterministic: # slower, more reproducible + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + else: + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6583eeb12bf43e8036fef3d583cc2a5e77540441 --- /dev/null +++ b/colossalai/testing/utils.py @@ -0,0 +1,278 @@ +import gc +import random +import re +import socket +from functools import partial +from inspect import signature +from typing import Any, Callable, List + +import torch +import torch.multiprocessing as mp +from packaging import version + + +def parameterize(argument: str, values: List[Any]) -> Callable: + """ + This function is to simulate the same behavior as pytest.mark.parameterize. As + we want to avoid the number of distributed network initialization, we need to have + this extra decorator on the function launched by torch.multiprocessing. + + If a function is wrapped with this wrapper, non-parametrized arguments must be keyword arguments, + positional arguments are not allowed. + + Usage:: + + # Example 1: + @parameterize('person', ['xavier', 'davis']) + def say_something(person, msg): + print(f'{person}: {msg}') + + say_something(msg='hello') + + # This will generate output: + # > xavier: hello + # > davis: hello + + # Example 2: + @parameterize('person', ['xavier', 'davis']) + @parameterize('msg', ['hello', 'bye', 'stop']) + def say_something(person, msg): + print(f'{person}: {msg}') + + say_something() + + # This will generate output: + # > xavier: hello + # > xavier: bye + # > xavier: stop + # > davis: hello + # > davis: bye + # > davis: stop + + Args: + argument (str): the name of the argument to parameterize + values (List[Any]): a list of values to iterate for this argument + """ + + def _wrapper(func): + + def _execute_function_by_param(**kwargs): + for val in values: + arg_map = {argument: val} + partial_func = partial(func, **arg_map) + partial_func(**kwargs) + + return _execute_function_by_param + + return _wrapper + + +def rerun_on_exception(exception_type: Exception = Exception, pattern: str = None, max_try: int = 5) -> Callable: + """ + A decorator on a function to re-run when an exception occurs. + + Usage:: + + # rerun for all kinds of exception + @rerun_on_exception() + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for RuntimeError only + @rerun_on_exception(exception_type=RuntimeError) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for maximum 10 times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, max_try=10) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for infinite times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, max_try=None) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun only the exception message is matched with pattern + # for infinite times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$") + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + Args: + exception_type (Exception, Optional): The type of exception to detect for rerun + pattern (str, Optional): The pattern to match the exception message. + If the pattern is not None and matches the exception message, + the exception will be detected for rerun + max_try (int, Optional): Maximum reruns for this function. The default value is 5. + If max_try is None, it will rerun forever if exception keeps occurring + """ + + def _match_lines(lines, pattern): + for line in lines: + if re.match(pattern, line): + return True + return False + + def _wrapper(func): + + def _run_until_success(*args, **kwargs): + try_count = 0 + assert max_try is None or isinstance(max_try, int), \ + f'Expected max_try to be None or int, but got {type(max_try)}' + + while max_try is None or try_count < max_try: + try: + try_count += 1 + ret = func(*args, **kwargs) + return ret + except exception_type as e: + error_lines = str(e).split('\n') + if try_count < max_try and (pattern is None or _match_lines(error_lines, pattern)): + print('Exception is caught, retrying...') + # when pattern is not specified, we always skip the exception + # when pattern is specified, we only skip when pattern is matched + continue + else: + print('Maximum number of attempts is reached or pattern is not matched, no more retrying...') + raise e + + # Override signature + # otherwise pytest.mark.parameterize will raise the following error: + # function does not use argument xxx + sig = signature(func) + _run_until_success.__signature__ = sig + + return _run_until_success + + return _wrapper + + +def rerun_if_address_is_in_use(): + """ + This function reruns a wrapped function if "address already in use" occurs + in testing spawned with torch.multiprocessing + + Usage:: + + @rerun_if_address_is_in_use() + def test_something(): + ... + + """ + # check version + torch_version = version.parse(torch.__version__) + assert torch_version.major == 1 + + # only torch >= 1.8 has ProcessRaisedException + if torch_version.minor >= 8: + exception = torch.multiprocessing.ProcessRaisedException + else: + exception = Exception + + func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*") + return func_wrapper + + +def skip_if_not_enough_gpus(min_gpus: int): + """ + This function is used to check the number of available GPUs on the system and + automatically skip the test cases which require more GPUs. + + Note: + The wrapped function must have `world_size` in its keyword argument. + + Usage: + @skip_if_not_enough_gpus(min_gpus=8) + def test_something(): + # will be skipped if there are fewer than 8 GPUs available + do_something() + + Arg: + min_gpus (int): the minimum number of GPUs required to run this test. + """ + + def _wrap_func(f): + + def _execute_by_gpu_num(*args, **kwargs): + num_avail_gpu = torch.cuda.device_count() + if num_avail_gpu >= min_gpus: + f(*args, **kwargs) + + return _execute_by_gpu_num + + return _wrap_func + + +def free_port() -> int: + """Get a free port on localhost. + + Returns: + int: A free port on localhost. + """ + while True: + port = random.randint(20000, 65000) + try: + with socket.socket() as sock: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("localhost", port)) + return port + except OSError: + continue + + +def spawn(func, nprocs=1, **kwargs): + """ + This function is used to spawn processes for testing. + + Usage: + # must contains arguments rank, world_size, port + def do_something(rank, world_size, port): + ... + + spawn(do_something, nprocs=8) + + # can also pass other arguments + def do_something(rank, world_size, port, arg1, arg2): + ... + + spawn(do_something, nprocs=8, arg1=1, arg2=2) + + Args: + func (Callable): The function to be spawned. + nprocs (int, optional): The number of processes to spawn. Defaults to 1. + """ + port = free_port() + wrapped_func = partial(func, world_size=nprocs, port=port, **kwargs) + mp.spawn(wrapped_func, nprocs=nprocs) + + +def clear_cache_before_run(): + """ + This function is a wrapper to clear CUDA and python cache before executing the function. + + Usage: + @clear_cache_before_run() + def test_something(): + ... + """ + + def _wrap_func(f): + + def _clear_cache(*args, **kwargs): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_max_memory_cached() + torch.cuda.synchronize() + gc.collect() + f(*args, **kwargs) + + return _clear_cache + + return _wrap_func diff --git a/colossalai/trainer/__init__.py b/colossalai/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..84e53dc4e87ac5b10a93aacc0fce975cc49c66eb --- /dev/null +++ b/colossalai/trainer/__init__.py @@ -0,0 +1,3 @@ +from ._trainer import Trainer + +__all__ = ['Trainer'] diff --git a/colossalai/trainer/_trainer.py b/colossalai/trainer/_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..60bbc4eeee32adbd46472106fff4a07c653f7cb6 --- /dev/null +++ b/colossalai/trainer/_trainer.py @@ -0,0 +1,409 @@ +from typing import Union, List, Any + +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from colossalai.engine import Engine +from colossalai.logging import DistributedLogger +from colossalai.utils import MultiTimer +from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage +from colossalai.trainer.hooks import BaseHook + + +class Trainer: + r"""This is a class tending for easy deployments of users' training and evaluation instead of + writing their own scripts. It is similar with ``ignite.engine`` and ``keras.engine``, but is + called `Trainer`. + + Args: + engine (:class:`Engine`): Engine responsible for the process function. + timer (:class:`MultiTimer`, optional): Timer used to monitor the whole training. + logger (:class:`colossalai.logging.DistributedLogger`, optional): Logger used to record the whole training log. + + + Examples: + >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training + >>> model = ... + >>> criterion = ... + >>> optimizer = ... + >>> train_dataloader = ... + >>> # Initialize your engine, train_dataloader, test_dataloader, lr_scheduler + >>> engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion) + >>> # Beginning training progress + >>> timier = ... + >>> logger = ... + >>> trainer = Trainer(engine=engine, logger=logger, timer=timier) + >>> # add hooks you would like to use here. + >>> hook_list = [] + >>> trainer.fit( + >>> train_dataloader=train_dataloader, + >>> epochs=gpc.config.NUM_EPOCHS, + >>> test_interval=1, + >>> hooks=hook_list, + >>> display_progress=True, + >>> return_output_label=False + >>> ) + + More examples and details could be found in + `Training with engine and trainer `_ + and `ColossalAI-Examples `_. + """ + + def __init__( + self, + engine: Engine, + timer: MultiTimer = None, + logger: DistributedLogger = None, + ): + # training-ralated params + self._engine = engine + self._max_epochs = 0 + self._cur_epoch = 0 + self._max_steps = 0 + self._cur_step = 0 + self._steps_per_epoch = 0 + + # misc params + self._logger = logger + self._verbose = logger is not None + + # hooks can store states in this dict, and could be consumed by other hooks + self.states = dict() + + # build hooks + self.hooks = list() + + # multi-timer for time benchmarking + self._timer = timer + + @property + def cur_epoch(self): + """Returns the index of the current epoch.""" + return self._cur_epoch + + @cur_epoch.setter + def cur_epoch(self, epoch: int): + """Set how many epochs have been processed.""" + # allow setter for training resumption + self._cur_epoch = epoch + + @property + def cur_step(self): + """Returns how many iteration steps have been processed.""" + return self._cur_step + + @property + def max_epochs(self): + return self._max_epochs + + @property + def max_steps(self): + return self._max_steps + + @property + def steps_per_epoch(self): + return self._steps_per_epoch + + @property + def engine(self): + return self._engine + + def _set_current_step(self, epoch: int): + """Sets current step number. + + Args: + epoch (int): Step number to be set. + """ + self._cur_step = epoch * self._steps_per_epoch + + def _call_timer(self, action: str, item: str, *args, **kwargs) -> None: + """Call timer funciton with a given timer name. + + Args: + action (str): Function to be called on timer. + item (str): Name of the timer. + args (list): args used for action function. + kwargs (dict): kwargs used for action function. + """ + + if self._timer is not None: + getattr(self._timer, action)(item, *args, **kwargs) + + def _reset_states(self) -> None: + """Clear trainer states""" + self.states = dict() + + def _call_hooks(self, func, output=None): + """Calls specific hooks in the current time point. + + Args: + func (str): A string represents the time point. + output (Any, optional): Output of the model after running an iteration or None in any other time points. + """ + # Only after iter hook will receive output + for hook in self.hooks: + if output is None: + getattr(hook, func)(self) + else: + getattr(hook, func)(self, *output) + + @staticmethod + def _should_display_progress(display_progress: bool): + """Only display progress on DP rank 0, TP rank 0 and PP last rank""" + return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()) + + def _train_epoch( + self, + train_dataloader: DataLoader, + epoch: int = None, + display_progress: bool = False, + return_output_label: bool = True, + ): + # set training state + self._engine.train() + data_iter = iter(train_dataloader) + progress = range(self._steps_per_epoch) + if display_progress: + if epoch is None: + progress = tqdm(progress, desc="[Train]") + else: + progress = tqdm(progress, desc=f"[Epoch {epoch} / Train]") + + self._call_hooks("before_train_epoch") + self._call_timer(action="start", item="Train-epoch") + for i in progress: + self._call_hooks("before_train_iter") + self._call_timer(action="start", item="Train-step") + + # run 1 training step + self.engine.zero_grad() + logits, label, loss = self.engine.execute_schedule( + data_iter, + forward_only=False, + return_loss=True, + return_output_label=return_output_label, + ) + self.engine.step() + self._call_timer(action="stop", item="Train-step", keep_in_history=True) + self._call_hooks("after_train_iter", output=(logits, label, loss)) + + self._cur_step += 1 + + if display_progress: + if "step_metrics" in self.states: + progress.set_postfix(**self.states["step_metrics"]) + + # stop when max iter is reached + if self._exceed_max_step(): + break + + self._call_timer(action="stop", item="Train-epoch", keep_in_history=True) + self._call_hooks("after_train_epoch") + self._call_timer(action="reset", item="Train-epoch") + + def _eval( + self, + test_dataloader: DataLoader, + epoch: int = None, + display_progress: bool = False, + return_output_label: bool = True, + ): + # switch engine status + self._engine.eval() + + data_iter = iter(test_dataloader) + num_steps = len(test_dataloader) + + self._call_hooks("before_test") + # prepare progress bar + progress = range(num_steps) + if display_progress: + desc = "Evaluation" + if epoch is not None: + desc = "[Epoch %d / Test]" % epoch + progress = tqdm(progress, desc=desc) + + self._call_hooks("before_test_epoch") + self._call_timer(action="start", item="Test-epoch") + with torch.no_grad(): + for _ in progress: + self._call_hooks("before_test_iter") + self._call_timer(action="start", item="Test-step") + logits, label, loss = self.engine.execute_schedule( + data_iter, + forward_only=True, + return_loss=True, + return_output_label=return_output_label, + ) + self._call_timer(action="stop", item="Test-step", keep_in_history=True) + self._call_hooks("after_test_iter", output=(logits, label, loss)) + + if display_progress: + if "step_metrics" in self.states: + progress.set_postfix(**self.states["step_metrics"]) + + self._call_timer(action="stop", item="Test-epoch", keep_in_history=True) + self._call_hooks("after_test_epoch") + self._call_hooks("after_test") + self._call_timer(action="reset", item="Test-step") + self._call_timer(action="reset", item="Test-epoch") + + def _exceed_max_step(self): + return self._max_steps is not None and self._cur_step >= self._max_steps + + def fit( + self, + train_dataloader: DataLoader, + epochs: int, + max_steps: int = None, + test_dataloader: DataLoader = None, + test_interval: int = 1, + hooks: List[BaseHook] = None, + display_progress: bool = False, + return_output_label: bool = True, + ): + r"""Trains the model to fit training data. + + Args: + train_dataloader (:class:`torch.utils.data.DataLoader`): DataLoader for training. + epochs (int): Maximum number of epochs. + max_steps (int, optional): Maximum number of running iterations. + test_dataloader (:class:`torch.utils.data.DataLoader`, optional): DataLoader for validation. + test_interval (int, optional): Interval of validation + hooks (list[BaseHook], optional): A list of hooks used in training. + display_progress (bool, optional): If True, a progress bar will be displayed. + """ + + # set epochs and steps, consider gradient accumulation + self._steps_per_epoch = len(train_dataloader) + self._max_steps = max_steps + self._max_epochs = epochs + + # check if testing is required + should_test = False + if test_dataloader is not None: + should_test = True + + display_progress = self._should_display_progress(display_progress) + + # reset hooks + self._reset_states() + if hooks is not None: + assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}" + + for hook in hooks: + assert isinstance(hook, BaseHook), \ + f'expected the hook to be of type BaseHook, but got {type(hook)}' + else: + hooks = [] + self.hooks = hooks + self.hooks.sort(key=lambda hook: hook.priority) + if self._verbose: + for hook in self.hooks: + self._logger.info( + f"Using {hook.__class__.__name__} for training, priority = {hook.priority}", + ranks=[0], + ) + self._logger.info("Lower value means higher priority for calling hook function", ranks=[0]) + self._call_hooks("after_hook_is_attached") + + self._engine.train() + self._call_hooks("before_train") + + # recover step value if resuming training + last_epoch = self._cur_epoch + if self.cur_epoch != 0: + self._set_current_step(last_epoch) + + for epoch in range(last_epoch, epochs): + # train for one epoch + self._train_epoch( + train_dataloader=train_dataloader, + epoch=epoch, + display_progress=display_progress, + return_output_label=return_output_label, + ) + + # start eval + if should_test and epoch % test_interval == 0: + self._eval( + test_dataloader=test_dataloader, + display_progress=display_progress, + epoch=epoch, + return_output_label=return_output_label, + ) + + self._cur_epoch += 1 + + # check for termination + if self._exceed_max_step(): + self._logger.info( + f"Max number of steps {max_steps} has been reached, training is stopped automatically", + ranks=[0], + ) + break + self._call_hooks("after_train") + self._call_timer("reset", "Train-epoch") + + def evaluate( + self, + test_dataloader: DataLoader, + hooks: List[BaseHook] = None, + display_progress: bool = False, + return_output_label: bool = True, + ): + """Evaluates the model with testing data. + + Args: + test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing. + hooks (list, optional): A list of hooks used in evaluation. Defaults to None. + display_progress (bool, optional): If True, the evaluation progress will be printed. Defaults to False. + return_output_label (bool, optional): If True, the output of model and the label + will be returned. Defaults to True. + """ + # set display + display_progress = self._should_display_progress(display_progress) + + # reset hooks + self._reset_states() + if hooks is not None: + assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}" + else: + hooks = [] + self.hooks = hooks + self.hooks.sort(key=lambda hook: hook.priority) + if self._verbose: + for hook in self.hooks: + self._logger.info( + f"Using {hook.__class__.__name__} for training, priority = {hook.priority}", + ranks=[0], + ) + self._logger.info("Lower value means higher priority for calling hook function", ranks=[0]) + self._call_hooks("after_hook_is_attached") + + # eval + self._eval( + test_dataloader=test_dataloader, + display_progress=display_progress, + return_output_label=return_output_label, + ) + + def predict(self, data: Union[Any, List[Any]]): + """Uses trained model to make a prediction for a tensor or a tensor list. + + Args: + data (Union[:class:`torch.tensor`, List[:class:`torch.tensor`]]): Data as the input. + + Returns: + :class:`torch.tensor`: The output of model as the prediction + """ + # predict without labels + self._engine.eval() + + # prepare a list of (data, label) to make it iterable + # for compatibility with schedule + simple_dataloader = [(data, None)] + data_iter = iter(simple_dataloader) + output, _, _ = self.engine.execute_schedule(data_iter, forward_only=True, return_loss=False) + return output diff --git a/colossalai/trainer/hooks/__init__.py b/colossalai/trainer/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d36093833d99429b15fb35962c930646b5cbf64 --- /dev/null +++ b/colossalai/trainer/hooks/__init__.py @@ -0,0 +1,12 @@ +from ._base_hook import BaseHook +from ._checkpoint_hook import SaveCheckpointHook +from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook, + TensorboardHook) +from ._lr_scheduler_hook import LRSchedulerHook +from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook + +__all__ = [ + 'BaseHook', 'MetricHook', 'LossHook', 'AccuracyHook', 'LogMetricByEpochHook', 'TensorboardHook', + 'LogTimingByEpochHook', 'LogMemoryByEpochHook', 'LRSchedulerHook', 'ThroughputHook', 'LogMetricByStepHook', + 'SaveCheckpointHook' +] diff --git a/colossalai/trainer/hooks/_base_hook.py b/colossalai/trainer/hooks/_base_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..cca8e081ec883b8b5f3d88633ec9f57cc9fd6dfc --- /dev/null +++ b/colossalai/trainer/hooks/_base_hook.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC + +from torch import Tensor + + +class BaseHook(ABC): + """This class allows users to add desired actions in specific time points + during training or evaluation. + + :param priority: Priority in the printing, hooks with small priority will be printed in front + :type priority: int + """ + + def __init__(self, priority: int) -> None: + self.priority = priority + + def after_hook_is_attached(self, trainer): + """Actions after hooks are attached to trainer. + """ + pass + + def before_train(self, trainer): + """Actions before training. + """ + pass + + def after_train(self, trainer): + """Actions after training. + """ + pass + + def before_train_iter(self, trainer): + """Actions before running a training iteration. + """ + pass + + def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): + """Actions after running a training iteration. + + Args: + trainer (:class:`Trainer`): Trainer which is using this hook. + output (:class:`torch.Tensor`): Output of the model. + label (:class:`torch.Tensor`): Labels of the input data. + loss (:class:`torch.Tensor`): Loss between the output and input data. + """ + pass + + def before_train_epoch(self, trainer): + """Actions before starting a training epoch. + """ + pass + + def after_train_epoch(self, trainer): + """Actions after finishing a training epoch. + """ + pass + + def before_test(self, trainer): + """Actions before evaluation. + """ + pass + + def after_test(self, trainer): + """Actions after evaluation. + """ + pass + + def before_test_epoch(self, trainer): + """Actions before starting a testing epoch. + """ + pass + + def after_test_epoch(self, trainer): + """Actions after finishing a testing epoch. + """ + pass + + def before_test_iter(self, trainer): + """Actions before running a testing iteration. + """ + pass + + def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): + """Actions after running a testing iteration. + + Args: + trainer (:class:`Trainer`): Trainer which is using this hook + output (:class:`torch.Tensor`): Output of the model + label (:class:`torch.Tensor`): Labels of the input data + loss (:class:`torch.Tensor`): Loss between the output and input data + """ + pass + + def init_runner_states(self, trainer, key, val): + """Initializes trainer's state. + + Args: + trainer (:class:`Trainer`): Trainer which is using this hook + key: Key of state to be reset + val: Value of state to be reset + """ + if key not in trainer.states: + trainer.states[key] = val diff --git a/colossalai/trainer/hooks/_checkpoint_hook.py b/colossalai/trainer/hooks/_checkpoint_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..3bcb32cd2dcbc46a9e57dfec1f72abe9bd4aabda --- /dev/null +++ b/colossalai/trainer/hooks/_checkpoint_hook.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import torch +from colossalai.logging import get_dist_logger + +from colossalai.registry import HOOKS +from colossalai.trainer.hooks import BaseHook +from colossalai.utils.checkpointing import save_checkpoint +from ._lr_scheduler_hook import LRSchedulerHook + + +@HOOKS.register_module +class SaveCheckpointHook(BaseHook): + """Saves the model by interval in training process. + + Args: + interval (int, optional): Number of epochs between saving the checkpoint, defaults to 1. + if save_by_iter is True, this arg refers to the number of iters between saving. + checkpoint_dir (str, optional): File name to save the checkpoint, defaults to None. + model (torch.nn.Module, Optional): The model to save, defaults to None. When not passing, + 'trainer.engine.model' will be used. We encourage you to pass the model in it to avoid some + unexpected bugs, especially when using **DDP**. + save_by_iter (bool, optional): Whether saving the checkpoint by iter, default to False. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, + interval: int = 1, + checkpoint_dir: str = None, + model: torch.nn.Module = None, + save_by_iter: bool = False, + priority: int = 10): + super().__init__(priority=priority) + self.interval = interval + self.checkpoint_dir = checkpoint_dir + self.model = model + self.save_by_iter = save_by_iter + self.logger = get_dist_logger() + + # get lr scheduler from the LRSchedulerHook before train + self._lr_scheduler = None + + def after_hook_is_attached(self, trainer): + # get lr scheduler if exists + for hook in trainer.hooks: + if isinstance(hook, LRSchedulerHook): + self._lr_scheduler = hook.lr_scheduler + break + self.model = self.model if self.model is not None else trainer.engine.model + + def after_train_iter(self, trainer, output, label, loss): + """Saves the model after a training iter. + """ + # save by interval + if self.save_by_iter and trainer.cur_step % self.interval == 0: + save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, + self._lr_scheduler) + self.logger.info(f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}', + ranks=[0]) + else: + pass + + def after_train_epoch(self, trainer): + """Saves the model after a training epoch. + """ + # save by interval + if trainer.cur_epoch % self.interval == 0: + save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, + self._lr_scheduler) + self.logger.info(f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0]) diff --git a/colossalai/trainer/hooks/_commons_.py b/colossalai/trainer/hooks/_commons_.py new file mode 100644 index 0000000000000000000000000000000000000000..4923b8cba6c04e482bd4f5163b33767d34e83ebb --- /dev/null +++ b/colossalai/trainer/hooks/_commons_.py @@ -0,0 +1,9 @@ +import torch + + +def _format_number(val, prec=5): + if isinstance(val, float): + return f'{val:.{prec}g}' + elif torch.is_tensor(val) and torch.is_floating_point(val): + return f'{val.item():.{prec}g}' + return val diff --git a/colossalai/trainer/hooks/_log_hook.py b/colossalai/trainer/hooks/_log_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..5b1f33983422b11389927190acbef90165a15e2a --- /dev/null +++ b/colossalai/trainer/hooks/_log_hook.py @@ -0,0 +1,301 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import os +import os.path as osp + +from typing import List +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.registry import HOOKS +from colossalai.logging import DistributedLogger +from colossalai.utils import report_memory_usage, is_dp_rank_0, \ + is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer +from ._base_hook import BaseHook +from ._commons_ import _format_number +from colossalai.trainer.hooks._metric_hook import ThroughputMetric + + +class LogByEpochHook(BaseHook): + """Hook to log by epoch. + + Args: + logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. + interval (int, optional): Interval of printing log information, defaults to 1. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, + defaults to 1. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, logger, interval: int = 1, priority: int = 1): + super().__init__(priority) + self.logger = logger + self._interval = interval + + def _is_epoch_to_log(self, trainer): + return trainer.cur_epoch % self._interval == 0 + + +@HOOKS.register_module +class LogMetricByStepHook(BaseHook): + """Hook to log metric by step. + + Args: + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, priority: int = 10): + super().__init__(priority) + + def after_train_iter(self, trainer, *args): + trainer.states['step_metrics'] = dict() + for metric_name, metric_calculator in trainer.states['metrics']['train'].items(): + if isinstance(metric_calculator, ThroughputMetric): + trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info() + else: + trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value() + + def after_test_iter(self, trainer, *args): + trainer.states['step_metrics'] = dict() + for metric_name, metric_calculator in trainer.states['metrics']['test'].items(): + if isinstance(metric_calculator, ThroughputMetric): + trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info() + else: + trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value() + + +@HOOKS.register_module +class LogMetricByEpochHook(LogByEpochHook): + """Specialized hook to record the metric to log. + + Args: + logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. + interval (int, optional): Interval of printing log information, defaults to 1. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, logger, interval: int = 1, priority: int = 10) -> None: + super().__init__(logger, interval, priority) + self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() + + def _get_str(self, trainer, mode): + msg = [] + for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): + msg.append(f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}') + msg = ' | '.join(msg) + return msg + + def after_train_epoch(self, trainer): + if self._is_epoch_to_log(trainer): + msg = self._get_str(trainer=trainer, mode='train') + + if self._is_rank_to_log: + self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}') + # f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}') + + def after_test_epoch(self, trainer): + if self._is_epoch_to_log(trainer): + msg = self._get_str(trainer=trainer, mode='test') + if self._is_rank_to_log: + self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}') + # f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}') + + +@HOOKS.register_module +class TensorboardHook(BaseHook): + """Specialized hook to record the metric to Tensorboard. + + Args: + log_dir (str): Directory of log. + ranks (list): Ranks of processors. + parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): Parallel mode used in trainer, + defaults to colossalai.context.parallel_mode.ParallelMode.GLOBAL. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__( + self, + log_dir: str, + ranks: List = None, + parallel_mode: ParallelMode = ParallelMode.GLOBAL, + priority: int = 10, + ) -> None: + super().__init__(priority=priority) + from torch.utils.tensorboard import SummaryWriter + + # create log dir + if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: + os.makedirs(log_dir, exist_ok=True) + + # determine the ranks to generate tensorboard logs + self._is_valid_rank_to_log = False + if not gpc.is_initialized(parallel_mode): + self._is_valid_rank_to_log = True + else: + local_rank = gpc.get_local_rank(parallel_mode) + + if ranks is None or local_rank in ranks: + self._is_valid_rank_to_log = True + + # check for + if gpc.is_initialized(ParallelMode.PIPELINE) and \ + not gpc.is_last_rank(ParallelMode.PIPELINE) and self._is_valid_rank_to_log: + raise ValueError("Tensorboard hook can only log on the last rank of pipeline process group") + + if self._is_valid_rank_to_log: + # create workspace on only one rank + if gpc.is_initialized(parallel_mode): + rank = gpc.get_local_rank(parallel_mode) + else: + rank = 0 + + # create workspace + log_dir = osp.join(log_dir, f'{parallel_mode}_rank_{rank}') + os.makedirs(log_dir, exist_ok=True) + + self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=f'_rank_{rank}') + + def _log_by_iter(self, trainer, mode: str): + for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): + if metric_calculator.epoch_only: + continue + val = metric_calculator.get_last_step_value() + + if self._is_valid_rank_to_log: + self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step) + + def _log_by_epoch(self, trainer, mode: str): + for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): + if metric_calculator.epoch_only: + val = metric_calculator.get_accumulated_value() + if self._is_valid_rank_to_log: + self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step) + + def after_test_iter(self, trainer, *args): + self._log_by_iter(trainer, mode='test') + + def after_test_epoch(self, trainer): + self._log_by_epoch(trainer, mode='test') + + def after_train_iter(self, trainer, *args): + self._log_by_iter(trainer, mode='train') + + def after_train_epoch(self, trainer): + self._log_by_epoch(trainer, mode='train') + + +@HOOKS.register_module +class LogTimingByEpochHook(LogByEpochHook): + """Specialized hook to write timing record to log. + + Args: + timer (:class:`colossalai.utils.MultiTimer`): Timer for the hook. + logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. + interval (int, optional): Interval of printing log information, defaults to 1. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + log_eval (bool, optional): Whether writes in evaluation, defaults to True. + ignore_num_train_steps (int, optional): Number of training steps to ignore, defaults to 0. + """ + + def __init__(self, + timer: MultiTimer, + logger: DistributedLogger, + interval: int = 1, + priority: int = 10, + log_eval: bool = True, + ignore_num_train_steps: int = 0) -> None: + super().__init__(logger=logger, interval=interval, priority=priority) + self._timer = timer + self._log_eval = log_eval + self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() + + # extra handling to avoid the unstable readings of the first + # few training steps to affect the history mean time + self._ignore_num_train_steps = ignore_num_train_steps + self._is_train_step_history_trimmed = False + + def _get_message(self, mode): + msg = [] + for timer_name, timer in self._timer: + if timer_name.startswith(mode): + last_elapsed_time = timer.get_elapsed_time() + if timer.has_history: + if timer_name == 'Train-step' and not self._is_train_step_history_trimmed: + timer._history = timer._history[self._ignore_num_train_steps:] + self._is_train_step_history_trimmed = True + history_mean = timer.get_history_mean() + history_sum = timer.get_history_sum() + msg.append( + f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s' + ) + else: + msg.append(f'{timer_name}: last = {_format_number(last_elapsed_time)} s') + + msg = ' | '.join(msg) + return msg + + def after_train_epoch(self, trainer): + """Writes log after finishing a training epoch. + """ + if self._is_epoch_to_log(trainer) and self._is_rank_to_log: + msg = self._get_message('Train') + self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg} | #steps/epoch = {trainer.steps_per_epoch}') + + def after_test_epoch(self, trainer): + """Writes log after finishing a testing epoch. + """ + if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval: + msg = self._get_message('Test') + self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}') + + +@HOOKS.register_module +class LogMemoryByEpochHook(LogByEpochHook): + """Specialized Hook to write memory usage record to log. + + Args: + logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. + interval (int, optional): Interval of printing log information, defaults to 1. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 1. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + log_eval (bool, optional): Whether writes in evaluation, defaults to True. + """ + + def __init__( + self, + logger: DistributedLogger, + interval: int = 1, + priority: int = 10, + log_eval: bool = True, + report_cpu: bool = False, # no reference + ) -> None: + super().__init__(logger=logger, interval=interval, priority=priority) + self._log_eval = log_eval + self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() + + def before_train(self, trainer): + """Resets before training. + """ + if self._is_epoch_to_log(trainer) and self._is_rank_to_log: + report_memory_usage('Before-train', self.logger) + + def after_train_epoch(self, trainer): + """Writes log after finishing a training epoch. + """ + if self._is_epoch_to_log(trainer) and self._is_rank_to_log: + report_memory_usage(f'[Epoch {trainer.cur_epoch} / Train]', self.logger) + + def after_test(self, trainer): + """Reports after testing. + """ + if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval: + report_memory_usage(f'[Epoch {trainer.cur_epoch} / Test]', self.logger) diff --git a/colossalai/trainer/hooks/_lr_scheduler_hook.py b/colossalai/trainer/hooks/_lr_scheduler_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..c6da33442dc39b78474fb36c50b3a9bbfc790666 --- /dev/null +++ b/colossalai/trainer/hooks/_lr_scheduler_hook.py @@ -0,0 +1,47 @@ +from colossalai.registry import HOOKS +from torch import Tensor + +from ._metric_hook import LearningRateMetric, MetricHook + + +@HOOKS.register_module +class LRSchedulerHook(MetricHook): + r"""Build LR scheduler for trainer. + + Args: + lr_scheduler (:class:`colossalai.nn.lr_scheduler`): The specific LR scheduler + in range of ``colossalai.nn.lr_scheduler``, more details about ``lr_scheduler`` could be found in + `lr_scheduler `_. + by_epoch (bool): If `True`, the LR will be scheduled every epoch. Else, the LR will be scheduled every batch. + store_lr_in_state (bool, optional): If `True`, store the learning rate in each state, defaults to `True`. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 1. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__( + self, + lr_scheduler, + by_epoch: bool, + store_lr_in_state: bool = True, + priority: int = 1, + ): + super().__init__(priority=priority) + self.by_epoch = by_epoch + self.lr_scheduler = lr_scheduler + self.store_lr_in_state = store_lr_in_state + + def after_hook_is_attached(self, trainer): + self._check_metric_states_initialization(trainer) + trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch, + initial_lr=self.lr_scheduler.get_last_lr()[0]) + + def after_train_epoch(self, trainer): + if self.by_epoch: + self.lr_scheduler.step() + trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0]) + + def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): + if not self.by_epoch: + self.lr_scheduler.step() + trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0]) diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..526d6c746ec6511c97b283ac1074340daabb2516 --- /dev/null +++ b/colossalai/trainer/hooks/_metric_hook.py @@ -0,0 +1,437 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod +from typing import Callable + +import torch +import torch.distributed as dist +from colossalai.communication import all_reduce +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.registry import HOOKS +from colossalai.utils import get_current_device, is_no_pp_or_last_stage + +from ._base_hook import BaseHook +from ._commons_ import _format_number + + +class Metric(ABC): + """A basic class of metric collectors. It collects a specific + metric during training or evaluation and would always be used with + :class:`MetricHook` to help it update its states and show the + metric. So please use corresponding hook class to make the metric + collector works. + + Args: + epoch_only (bool): Whether the metric only read for the full epoch. + """ + + def __init__(self, epoch_only: bool): + # is the metric only read for the full epoch + self._epoch_only = epoch_only + + @property + def epoch_only(self): + """Returns :attr:`epoch_only`. + """ + return self._epoch_only + + @abstractmethod + def reset(self) -> None: + """Resets the metric to it's initial state. + By default, this is called at the start of each epoch. + """ + pass + + @abstractmethod + def update(self, *args, **kwargs) -> None: + """Updates the metric's state using the passed batch output. + By default, this is called once for each batch. + """ + pass + + @abstractmethod + def get_last_step_value(self) -> float: + """Returns the metric value in the last iteration. + """ + pass + + @abstractmethod + def get_accumulated_value(self): + """Computes the metric based on it's accumulated state. + By default, this is called at the end of each epoch. + + :return: the actual quantity of interest + :rtype: Any + """ + pass + + @staticmethod + @abstractmethod + def is_better(a, b) -> bool: + """Compares a and b, and returns whether a is better than b + + :return: The result of comparison + :rtype: bool + """ + pass + + +class LossMetric(Metric): + """A metric collector for loss. + + Args: + epoch_only (bool): Whether the metric only read for the full epoch. + """ + + def __init__(self, epoch_only): + super().__init__(epoch_only=epoch_only) + self.last_step_loss = torch.zeros(1, device=get_current_device()) + self.accum_loss = torch.zeros(1, device=get_current_device()) + self.count = 0 + + def reset(self) -> None: + """Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero. + """ + self.last_step_loss.zero_() + self.accum_loss.zero_() + self.count = 0 + + def update(self, loss) -> None: + """Updates :attr:`last_step_loss` and :attr:`accum_loss` with current loss. + It expects the output has loss. + + Args: + loss (:class:`torch.tensor`): Current loss of the output. + """ + # expect output to be logits, label and loss + loss_ = loss.detach() + self.last_step_loss.copy_(loss_) + self.accum_loss.add_(loss_) + self.count += 1 + + def get_accumulated_value(self): + """Returns accumulated loss. + """ + if gpc.is_initialized(ParallelMode.DATA): + dist.all_reduce(self.accum_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA)) + self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA)) + + self.accum_loss.div_(self.count) + return self.accum_loss.item() + + def get_last_step_value(self) -> float: + """Returns :attr:`last_step_loss`. + """ + return self.last_step_loss.cpu().item() + + @staticmethod + def is_better(a, b): + return a < b + + +class LearningRateMetric(Metric): + """A metric collector for learning rate. + + Args: + epoch_only (bool): Whether the metric only read for the full epoch. + initial_lr (float, optional): Initial learning rate, defaults to 0.0. + """ + + def __init__(self, epoch_only: bool, initial_lr: float = 0.): + super().__init__(epoch_only=epoch_only) + self.lr = initial_lr + + def reset(self) -> None: + pass + + def update(self, lr) -> None: + self.lr = lr + + def get_last_step_value(self) -> float: + return self.lr + + def get_accumulated_value(self): + return self.lr + + @staticmethod + def is_better(a, b) -> bool: + pass + + +class AccuracyMetric(Metric): + """A metric collector for accuracy. It only works for classification + tasks. + + Args: + epoch_only (bool): Whether the metric only read for the full epoch. + accuracy_func (:class:`typing.Callable`): Accuracy function for the classification task. + """ + + def __init__(self, epoch_only: bool, accuracy_func: Callable): + super().__init__(epoch_only=epoch_only) + self.acc = accuracy_func + self.last_step_sum = torch.zeros(1, device=get_current_device()) + self.last_step_correct = torch.zeros(1, device=get_current_device()) + self.accumulated_sum = torch.zeros(1, device=get_current_device()) + self.accumulated_correct = torch.zeros(1, device=get_current_device()) + + def reset(self) -> None: + self.last_step_sum.zero_() + self.last_step_correct.zero_() + self.accumulated_sum.zero_() + self.accumulated_correct.zero_() + + def update(self, logits, targets, batch_size) -> None: + """Updates last step accuracy and accumulated accuracy with current logits + and labels. It expects the output has logits and labels. + + Args: + logits (:class:`torch.tensor`): The logits output of the model. + targets (:class:`torch.tensor`): Real labels of the dataset. + batch_size (int): Batch size of the task. + """ + if isinstance(logits, (list, tuple)): + logits = logits[0] + if isinstance(targets, (list, tuple)): + targets = targets[0] + # update + correct = self.acc(logits, targets) + + self.last_step_sum.fill_(batch_size) + self.last_step_correct.fill_(correct) + self.accumulated_sum += self.last_step_sum + self.accumulated_correct += self.last_step_correct + + def get_last_step_value(self) -> float: + self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA) + self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA) + return _format_number((self.last_step_correct / self.last_step_sum).cpu().item()) + + def get_accumulated_value(self): + self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA) + self.accumulated_correct = all_reduce(self.accumulated_correct, ParallelMode.DATA) + return (self.accumulated_correct / self.accumulated_sum).item() + + @staticmethod + def is_better(a, b) -> bool: + return a > b + + +class MetricHook(BaseHook): + """Specialized hook classes for :class:`Metric`. + Some help metric collectors initialize, reset and + update their states. Others are used to display and + record the metric. + + Args: + priority (int): Priority in the printing, hooks with small priority will be printed in front + defaults to 1. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__( + self, + priority: int, + ): + super().__init__(priority) + self._is_stage_to_compute = is_no_pp_or_last_stage() + + def _check_metric_states_initialization(self, trainer): + if 'metrics' not in trainer.states: + self.init_runner_states(trainer, 'metrics', dict(train={}, test={})) + + +@HOOKS.register_module +class LossHook(MetricHook): + """Specialized hook class for :class:`Loss`. + + Args: + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 0. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, priority: int = 0): + super().__init__(priority) + + def after_hook_is_attached(self, trainer): + self._check_metric_states_initialization(trainer) + + if self._is_stage_to_compute: + self.train_loss = LossMetric(epoch_only=False) + self.test_loss = LossMetric(epoch_only=True) + + # register the metric calculator + trainer.states['metrics']['train']['Loss'] = self.train_loss + trainer.states['metrics']['test']['Loss'] = self.test_loss + + def before_train_epoch(self, trainer): + if self._is_stage_to_compute: + self.train_loss.reset() + + def after_train_iter(self, trainer, logits, label, loss): + if self._is_stage_to_compute: + self.train_loss.update(loss) + + def before_test_epoch(self, trainer): + if self._is_stage_to_compute: + self.test_loss.reset() + + def after_test_iter(self, trainer, logits, label, loss): + if self._is_stage_to_compute: + self.test_loss.update(loss) + + +@HOOKS.register_module +class AccuracyHook(MetricHook): + """Specialized hook class for :class:`Accuracy`. + + Args: + accuracy_func (:class:`typing.Callable`): Accuracy function for the classification task. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 0. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, accuracy_func: Callable, priority: int = 0): + super().__init__(priority) + self.accuracy_func = accuracy_func + + def after_hook_is_attached(self, trainer): + self._check_metric_states_initialization(trainer) + if self._is_stage_to_compute: + self.metric = AccuracyMetric(epoch_only=True, accuracy_func=self.accuracy_func) + + # register the metric + trainer.states['metrics']['test']['Accuracy'] = self.metric + + def before_test(self, trainer): + if self._is_stage_to_compute: + self.metric.reset() + + def after_test_iter(self, trainer, logits, targets, *args): + if self._is_stage_to_compute: + batch_size = trainer.engine.schedule.batch_size + self.metric.update(logits, targets, batch_size) + + +class ThroughputMetric(Metric): + """Metric for :class:`Throughput`. + + Args: + epoch_only (bool): Whether the metric only read for the full epoch. + """ + + def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int = 0, use_local: bool = False): + super().__init__(epoch_only=epoch_only) + self.ignored_steps = ignored_steps + self.cur_steps = 0 + self.accumulated_num_samples = torch.zeros(1, device=get_current_device()) + self.accumulated_used_time = torch.zeros(1, device=get_current_device()) + self.last_step_num_samples = torch.zeros(1, device=get_current_device()) + self.last_step_used_time = torch.zeros(1, device=get_current_device()) + self._tflop_per_step = tflop_per_step + self._use_local = use_local + + def reset(self) -> None: + # self.cur_steps = 0 + self.accumulated_num_samples.zero_() + self.accumulated_used_time.zero_() + self.last_step_num_samples.zero_() + self.last_step_used_time.zero_() + + def update(self, num_samples, time) -> None: + self.cur_steps += 1 + self.last_step_num_samples.fill_(num_samples) + self.last_step_used_time.fill_(time) + if self.cur_steps >= self.ignored_steps: + self.accumulated_num_samples += self.last_step_num_samples + self.accumulated_used_time += self.last_step_used_time + + def get_last_step_value(self) -> float: + if self._use_local: + self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA) + else: + self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ + gpc.get_world_size(ParallelMode.DATA) + self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA) + + sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item()) + return sample_per_sec + + def get_last_step_info(self) -> str: + if self._use_local: + self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA) + else: + self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ + gpc.get_world_size(ParallelMode.DATA) + self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA) + + sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item()) + if self._tflop_per_step > 0: + tflops = _format_number(self._tflop_per_step / (self.last_step_used_time.item() + 1e-12)) + return f"{sample_per_sec} sample_per_sec, {tflops} Tflops" + else: + return f"{sample_per_sec} sample_per_sec" + + def get_accumulated_value(self) -> float: + self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \ + gpc.get_world_size(ParallelMode.DATA) + self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA) + return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item() + + @staticmethod + def is_better(a, b) -> bool: + pass + + +@HOOKS.register_module +class ThroughputHook(MetricHook): + """Specialized hook class for :class:`Throughput`. Hook to measure execution throughput (samples/sec). + + Args: + ignored_steps (int, optional): the number of initial training steps to ignore. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + tflop_per_step(int, optional): tera floating point operations per step. + use_local (bool, optional): Whether to use local time for throughput calculation. + """ + + def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: int = 0, use_local=False): + super().__init__(priority) + self.ignored_steps = ignored_steps + self._tflop_per_step = tflop_per_step + self._use_local = use_local + + def after_hook_is_attached(self, trainer): + self._check_metric_states_initialization(trainer) + if self._is_stage_to_compute: + self.metric = ThroughputMetric(epoch_only=True, + ignored_steps=self.ignored_steps, + tflop_per_step=self._tflop_per_step, + use_local=self._use_local) + + # register the metric + trainer.states['metrics']['train']['Throughput'] = self.metric + trainer.states['metrics']['test']['Throughput'] = self.metric + + def before_train_epoch(self, trainer): + if self._is_stage_to_compute: + self.metric.reset() + + def after_train_iter(self, trainer, *args): + if self._is_stage_to_compute: + self.metric.update(trainer.engine.schedule.batch_size, + trainer._timer.get_timer('Train-step').get_elapsed_time()) + + def before_test(self, trainer): + if self._is_stage_to_compute: + self.metric.reset() + + def after_test_iter(self, trainer, *args): + if self._is_stage_to_compute: + self.metric.update(trainer.engine.schedule.batch_size, + trainer._timer.get_timer('Test-step').get_elapsed_time()) diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b2e8480c66ce09bb5ad99070a2c22c7cb380697 --- /dev/null +++ b/colossalai/utils/__init__.py @@ -0,0 +1,75 @@ +from .activation_checkpoint import checkpoint +from .checkpointing import load_checkpoint, save_checkpoint +from .common import ( + clip_grad_norm_fp32, + conditional_context, + copy_tensor_parallel_attributes, + count_zeros_fp32, + disposable, + ensure_path_exists, + is_ddp_ignored, + is_dp_rank_0, + is_model_parallel_parameter, + is_no_pp_or_last_stage, + is_tp_rank_0, + is_using_ddp, + is_using_pp, + is_using_sequence, + multi_tensor_applier, + param_is_not_tensor_parallel_duplicate, + print_rank_0, + switch_virtual_pipeline_parallel_rank, + sync_model_param, +) +from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize +from .data_sampler import DataParallelSampler, get_dataloader +from .memory import ( + colo_device_memory_capacity, + colo_device_memory_used, + colo_get_cpu_memory_capacity, + colo_set_cpu_memory_capacity, + colo_set_process_memory_fraction, + report_memory_usage, +) +from .tensor_detector import TensorDetector +from .timer import MultiTimer, Timer + +__all__ = [ + 'checkpoint', + 'print_rank_0', + 'sync_model_param', + 'is_ddp_ignored', + 'is_dp_rank_0', + 'is_tp_rank_0', + 'is_no_pp_or_last_stage', + 'is_using_ddp', + 'is_using_pp', + 'is_using_sequence', + 'conditional_context', + 'is_model_parallel_parameter', + 'clip_grad_norm_fp32', + 'count_zeros_fp32', + 'copy_tensor_parallel_attributes', + 'param_is_not_tensor_parallel_duplicate', + 'get_current_device', + 'synchronize', + 'empty_cache', + 'set_to_cuda', + 'report_memory_usage', + 'colo_device_memory_capacity', + 'colo_device_memory_used', + 'colo_set_process_memory_fraction', + 'Timer', + 'MultiTimer', + 'multi_tensor_applier', + 'DataParallelSampler', + 'get_dataloader', + 'switch_virtual_pipeline_parallel_rank', + 'TensorDetector', + 'load_checkpoint', + 'save_checkpoint', + 'ensure_path_exists', + 'disposable', + 'colo_set_cpu_memory_capacity', + 'colo_get_cpu_memory_capacity', +] diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/utils/activation_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..fa9ed827a8a7fa649d5689d15a16050c9bc877b6 --- /dev/null +++ b/colossalai/utils/activation_checkpoint.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +from torch.utils.checkpoint import check_backward_validity, detach_variable + +from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states +from .cuda import get_current_device + +import weakref + + +def copy_to_device(obj, device): + if torch.is_tensor(obj): + # Notice: + # When in no_grad context, requires_gard is False after movement + ret = obj.to(device).detach() + ret.requires_grad = obj.requires_grad + return ret + elif isinstance(obj, list): + return [copy_to_device(i, device) for i in obj] + elif isinstance(obj, tuple): + return tuple([copy_to_device(v, device) for v in obj]) + elif isinstance(obj, dict): + return {k: copy_to_device(v, device) for k, v in obj.items()} + else: + return obj + + +class CheckpointFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, run_function, activation_offload=False, *args): + check_backward_validity(args) + ctx.run_function = run_function + ctx.activation_offload = activation_offload + ctx.device = get_current_device() + + # preserve rng states + ctx.fwd_cpu_rng_state = torch.get_rng_state() + sync_states() + ctx.fwd_seed_states = get_states(copy=True) + ctx.fwd_current_mode = get_current_mode() + + if hasattr(torch, 'is_autocast_enabled'): + ctx.had_autocast_in_fwd = torch.is_autocast_enabled() + else: + ctx.had_autocast_in_fwd = False + + if activation_offload: + inputs_cuda = copy_to_device(args, ctx.device) + else: + inputs_cuda = args + + with torch.no_grad(): + outputs = run_function(*inputs_cuda) + # Save non-tensor inputs in ctx, keep a placeholder None for tensors + # to be filled out during the backward. + ctx.inputs = [] + ctx.tensor_indices = [] + tensor_inputs = [] + for i, arg in enumerate(args): + if torch.is_tensor(arg): + if activation_offload: + tensor_inputs.append(copy_to_device(arg, 'cpu')) + else: + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + if activation_offload: + ctx.tensor_inputs = tensor_inputs + else: + ctx.save_for_backward(*tensor_inputs) + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter is " + "passed to .backward(). Please use .backward() and do not pass its `inputs` argument.") + # Copy the list to avoid modifying original list. + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + + if ctx.activation_offload: + tensors = ctx.tensor_inputs + else: + tensors = ctx.saved_tensors + + # store the current states + bwd_cpu_rng_state = torch.get_rng_state() + sync_states() + bwd_seed_states = get_states(copy=True) + bwd_current_mode = get_current_mode() + + # set the states to what it used to be + torch.set_rng_state(ctx.fwd_cpu_rng_state) + for parallel_mode, state in ctx.fwd_seed_states.items(): + set_seed_states(parallel_mode, state) + set_mode(ctx.fwd_current_mode) + if ctx.activation_offload: + tensors = copy_to_device(tensors, ctx.device) + + # Fill in inputs with appropriate saved tensors. + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + detached_inputs = detach_variable(tuple(inputs)) + if ctx.had_autocast_in_fwd: + with torch.enable_grad(), torch.cuda.amp.autocast(): + outputs = ctx.run_function(*detached_inputs) + else: + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + # recover the rng states + torch.set_rng_state(bwd_cpu_rng_state) + for parallel_mode, state in bwd_seed_states.items(): + set_seed_states(parallel_mode, state) + set_mode(bwd_current_mode) + + # run backward() with only tensor that requires grad + outputs_with_grad = [] + args_with_grad = [] + for i in range(len(outputs)): + if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: + outputs_with_grad.append(outputs[i]) + args_with_grad.append(args[i]) + if len(outputs_with_grad) == 0: + raise RuntimeError("none of output has requires_grad=True," + " this checkpoint() is not necessary") + torch.autograd.backward(outputs_with_grad, args_with_grad) + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs) + return (None, None) + grads + + +def checkpoint(function, activation_offload, *args, use_reentrant: bool = True): + """Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint. + + Args: + function: Describe the forward pass function. It should know how to handle the input tuples. + activation_offload: The variable to check whether we should offload activation to cpu + args (list): Tuple containing the parameters of the function + use_reentrant: Bool type to check if we need to use_reentrant, if use_reentrant=False, there + might be more flexibility for user to define there checkpoint function + + Returns: + Output of running function with provided args. + """ + if use_reentrant: + return CheckpointFunction.apply(function, activation_offload, *args) + else: + return _checkpoint_without_reentrant( + function, + activation_offload, + *args, + ) + + +def _checkpoint_without_reentrant(function, activation_offload=False, *args): + # store rng_state + fwd_cpu_state = torch.get_rng_state() + sync_states() + fwd_seed_states = get_states(copy=True) + fwd_current_mode = get_current_mode() + + # check if use autocast + if hasattr(torch, 'is_autocast_enabled'): + has_autocast_in_fwd = torch.is_autocast_enabled() + else: + has_autocast_in_fwd = False + + # using WeakKeyDictionary to store all the activation the first time we call unpack + storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + weak_holder_list = [] + + # class for weakref.ref + class Holder(): + pass + + # return a Holder object for later unpack process + def pack(x): + res = Holder() + weak_holder_list.append(weakref.ref(res)) + return res + + # unpack hook + def unpack(x): + unpack_counter = 0 + + # re-compute all the activation inside the function when we first call unpack + if len(storage) == 0: + + def inner_pack(inner): + nonlocal unpack_counter + unpack_counter += 1 + + # If the holder went out of scope, the SavedVariable is dead and so + # the value will never be read from the storage. Skip filling it. + if weak_holder_list[unpack_counter - 1]() is None: + return + + # Use detach here to ensure we don't keep the temporary autograd + # graph created during the second forward + storage[weak_holder_list[unpack_counter - 1]()] = inner.detach() + return + + def inner_unpack(packed): + raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.") + + # restore rng state + torch.set_rng_state(fwd_cpu_state) + for parallel_mode, state in fwd_seed_states.items(): + set_seed_states(parallel_mode, state) + set_mode(fwd_current_mode) + + # reload arg into device if needed + if activation_offload: + for arg in args: + if torch.is_tensor(arg): + arg = arg.to(device=device) + + # rerun forward, the inner_pack will store all the activations in storage + if has_autocast_in_fwd: + with torch.enable_grad(), \ + torch.cuda.amp.autocast(), \ + torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + _unused = function(*args) + else: + with torch.enable_grad(), \ + torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + _unused = function(*args) + + if x not in storage: + raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint" + " recomputation being triggered in between, this is not currently supported. Please" + " open an issue with details on your use case so that we can prioritize adding this.") + + return storage[x] + + # get device if we need to offload the activation + if activation_offload: + device = get_current_device() + + # run function with pack and unpack as saved_tensors_hooks + with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + output = function(*args) + + # offload activation if needed + if activation_offload: + for arg in args: + if torch.is_tensor(arg): + arg = arg.to(device="cpu") + + return output diff --git a/colossalai/utils/checkpoint/__init__.py b/colossalai/utils/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1795b4ce36f41d2a09da0c324db4cb1ef21c5e2c --- /dev/null +++ b/colossalai/utils/checkpoint/__init__.py @@ -0,0 +1,3 @@ +from .module_checkpoint import save_checkpoint, load_checkpoint + +__all__ = ['save_checkpoint', 'load_checkpoint'] diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..d390da864cd387445991260b31331a6978248970 --- /dev/null +++ b/colossalai/utils/checkpoint/module_checkpoint.py @@ -0,0 +1,137 @@ +import torch +import torch.distributed as dist +from colossalai.tensor import ColoTensor +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor +from typing import Optional, Dict + + +def save_checkpoint(path: str, + epoch: int, + model: torch.nn.Module, + optimizer: Optional[ColossalaiOptimizer] = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + *args, + **kwargs): + """save_checkpoint + save a model, whose parameters are `ColoTensor`s. + Args: + path (str): directory to save the checkpoint files. + epoch (int): the number of epoch + model (torch.nn.Module): a torch module initialized by ColoInitContext + optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None. + lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None. + """ + rank = dist.get_rank() + model_state = model.state_dict() + # save the dist context about the tensors in a new dict, while still maintain the original dict. + for k, v in model_state.items(): + if isinstance(v, ColoTensor): + gather_tensor(v) # gather shared tensors to rank0 + # don't recover tensors in rank0, since the dict is only a copy of model + + if rank == 0: + # sanity check + for k, v in model_state.items(): + if isinstance(v, ColoTensor): + assert v.save_ready + assert v.is_replicate() + delattr(v, 'save_ready') + # model saving + save_state = {'epoch': epoch, 'model': model_state} + torch.save(save_state, path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs) + + # delete old dicts + del model_state + # synchronize all the processes + dist.barrier() + + if optimizer is not None: + mapping = dict() + optim_state = optimizer.state_dict() + for k, v in optim_state['state'].items(): + for n, t in v.items(): + if isinstance(t, ColoTensor): + mapping[(k, n)] = t.dist_spec + gather_tensor(t) + + if rank == 0: + save_state = {'epoch': epoch, 'optim': optim_state} + torch.save(save_state, path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs) + # recover colo tensors in rank0 + for k, v in optimizer.state_dict()['state'].items(): + for n, t in v.items(): + if isinstance(t, ColoTensor): + assert hasattr(t, 'save_ready') + t.set_dist_spec(mapping[(k, n)]) + delattr(t, 'save_ready') + + del optim_state + del mapping + dist.barrier() + + +def load_checkpoint(path: str, + epoch: int, + model: torch.nn.Module, + optimizer: Optional[ColossalaiOptimizer] = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + torch_load_kwargs: Optional[Dict] = None, + load_state_dict_kwargs: Optional[Dict] = None): + """load_checkpoint + load a model, whose parameters are `ColoTensor`s. + Args: + path (str): directory to save the checkpoint files. + epoch (int): the number of epoch + model (torch.nn.Module): a torch module initialized by ColoInitContext + optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None. + lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None. + torch_load_kwargs: (dict, optional): The kwargs of torch.load inside the function + load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function + """ + # initialize the default parameters + if not torch_load_kwargs: + torch_load_kwargs = dict() + if not load_state_dict_kwargs: + load_state_dict_kwargs = dict() + + rank = dist.get_rank() + mapping = dict() + for n, p in model.named_parameters(): + if isinstance(p, ColoTensor): + mapping[n] = p.dist_spec + gather_tensor(p) + + if rank == 0: + load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), **torch_load_kwargs) + model.load_state_dict(load_state['model'], **load_state_dict_kwargs) + dist.barrier() + + # scatter loaded parameters + for n, p in model.named_parameters(): + if isinstance(p, ColoTensor): + scatter_tensor(p, mapping[n]) + if rank == 0: + assert hasattr(p, 'save_ready') + delattr(p, 'save_ready') + del mapping + + if optimizer is not None: + mapping = dict() + for k, v in optimizer.state_dict()['state'].items(): + for n, t in v.items(): + if isinstance(t, ColoTensor): + mapping[(k, n)] = t.dist_spec + gather_tensor(t) + + if rank == 0: + colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), **torch_load_kwargs) + optimizer.load_state_dict(colo_checkpoint['optim'], **load_state_dict_kwargs) + dist.barrier() + + for k, v in optimizer.state_dict()['state'].items(): + for n, t in v.items(): + if isinstance(t, ColoTensor): + scatter_tensor(t, mapping[(k, n)]) + + del mapping diff --git a/colossalai/utils/checkpoint/utils.py b/colossalai/utils/checkpoint/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..682cd0903d5b3b8028ea0741132f256b384d4cce --- /dev/null +++ b/colossalai/utils/checkpoint/utils.py @@ -0,0 +1,63 @@ +import torch +import torch.distributed as dist +from colossalai.tensor import ColoTensor, ColoTensorSpec +from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern + + +def robust_broadcast(tensor): + with torch.no_grad(): + is_cpu_ten = tensor.device.type == 'cpu' + if is_cpu_ten: + b_data = tensor.cuda() + else: + b_data = tensor + + dist.broadcast(b_data, 0) + + if is_cpu_ten: + tensor.copy_(b_data) + + +def gather_tensor(colo_tensor: ColoTensor) -> None: + """Make colo_tensor replicated when the rank is 0 + """ + if not colo_tensor.is_replicate(): + pg = colo_tensor.get_process_group() + # for the group which contains rank 0 + if pg.dp_local_rank() == 0: + old_dist_spec = colo_tensor.dist_spec + colo_tensor.to_replicate_() + if dist.get_rank() != 0: + colo_tensor.set_dist_spec(old_dist_spec) + + # synchronize all processes for unexpected problems + dist.barrier() + + if dist.get_rank() == 0: + setattr(colo_tensor, 'save_ready', True) # set saving signature + + +def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None: + """Reversal operation of `gather_tensor`. + """ + if dist_spec.placement == DistPlacementPattern.REPLICATE: + robust_broadcast(colo_tensor.data) + else: + global_size = colo_tensor.size_global() + + if dist.get_rank() == 0: + entire_data = colo_tensor.data + else: + entire_data = torch.empty(global_size, device=colo_tensor.device) + robust_broadcast(entire_data) + + if dist.get_rank() == 0: + colo_tensor.set_dist_spec(dist_spec) + else: + rep_tensor = ColoTensor( + entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec)) + rep_tensor.set_dist_spec(dist_spec) + with torch.no_grad(): + colo_tensor.data.copy_(rep_tensor.data) + # synchronize all processes for unexpected problems + dist.barrier() diff --git a/colossalai/utils/checkpoint_io/__init__.py b/colossalai/utils/checkpoint_io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe030866894f7666195f1563f6699ae88aa78a61 --- /dev/null +++ b/colossalai/utils/checkpoint_io/__init__.py @@ -0,0 +1,2 @@ +from .io import load, merge, redist, save +from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta) diff --git a/colossalai/utils/checkpoint_io/backend.py b/colossalai/utils/checkpoint_io/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..140192c05f12cf4843df36d43c66723469ef6cad --- /dev/null +++ b/colossalai/utils/checkpoint_io/backend.py @@ -0,0 +1,74 @@ +import shutil +import tempfile +from abc import ABC, abstractmethod +from typing import Dict, List, Type + +from .reader import CheckpointReader, DiskCheckpointReader +from .writer import CheckpointWriter, DiskCheckpointWriter + +_backends: Dict[str, Type['CheckpointIOBackend']] = {} + + +def register(name: str): + assert name not in _backends, f'"{name}" is registered' + + def wrapper(cls): + _backends[name] = cls + return cls + + return wrapper + + +def get_backend(name: str) -> 'CheckpointIOBackend': + assert name in _backends, f'Unsupported backend "{name}"' + return _backends[name]() + + +class CheckpointIOBackend(ABC): + + def __init__(self) -> None: + super().__init__() + self.temps: List[str] = [] + + @abstractmethod + def get_writer(self, + base_name: str, + overwrite: bool = False, + rank: int = 0, + world_size: int = 1) -> CheckpointWriter: + pass + + @abstractmethod + def get_reader(self, base_name: str) -> CheckpointReader: + pass + + @abstractmethod + def get_temp(self, base_name: str) -> str: + pass + + @abstractmethod + def clean_temp(self) -> None: + pass + + +@register('disk') +class CheckpointDiskIO(CheckpointIOBackend): + + def get_writer(self, + base_name: str, + overwrite: bool = False, + rank: int = 0, + world_size: int = 1) -> CheckpointWriter: + return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size) + + def get_reader(self, base_name: str) -> CheckpointReader: + return DiskCheckpointReader(base_name) + + def get_temp(self, base_name: str) -> str: + temp_dir_name = tempfile.mkdtemp(dir=base_name) + self.temps.append(temp_dir_name) + return temp_dir_name + + def clean_temp(self) -> None: + for temp_dir_name in self.temps: + shutil.rmtree(temp_dir_name) diff --git a/colossalai/utils/checkpoint_io/constant.py b/colossalai/utils/checkpoint_io/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..2199484741bf5bb934d5c0583dd55f51f0cdbffe --- /dev/null +++ b/colossalai/utils/checkpoint_io/constant.py @@ -0,0 +1,9 @@ +import re + +GLOBAL_META_FILE_NAME = 'global_meta.bin' +MODEL_CKPT_FILE_NAME = 'model.bin' +OPTIM_CKPT_FILE_NAME = 'optim.bin' +META_CKPT_FILE_NAME = 'meta.bin' +OTHER_CKPT_FILE_NAME = 'other.bin' + +CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other') diff --git a/colossalai/utils/checkpoint_io/convertor.py b/colossalai/utils/checkpoint_io/convertor.py new file mode 100644 index 0000000000000000000000000000000000000000..529ceb86829b511d05911fe6103d5ef53ecf325d --- /dev/null +++ b/colossalai/utils/checkpoint_io/convertor.py @@ -0,0 +1,227 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional + +from torch import Tensor + +from .distributed import merge_param, unmerge_param +from .meta import ParamDistMeta, RedistMeta +from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none) + + +class CheckpointConvertor(ABC): + + @abstractmethod + def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + pass + + @abstractmethod + def complete(self) -> None: + pass + + +class ModelCheckpointConvertor(CheckpointConvertor): + + def __init__(self, param_count: Dict[str, int]) -> None: + super().__init__() + self.param_count = param_count + self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict) + + @abstractmethod + def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: + pass + + def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + for rank, state_dict in shard_dict.items(): + for k, tensor in state_dict.items(): + self.buffer[k][rank] = tensor + converted_keys = set() + for k, rank_dict in self.buffer.items(): + if len(rank_dict) == self.param_count[k]: + tensors = [] + dist_metas = [] + for rank, tensor in rank_dict.items(): + tensors.append(tensor) + if dist_meta_list[rank] is not None: + dist_metas.append(dist_meta_list[rank][k]) + self.convert_tensors(k, tensors, dist_metas) + converted_keys.add(k) + for k in converted_keys: + del self.buffer[k] + + def complete(self) -> None: + assert len(self.buffer) == 0 + + +class ModelCheckpointMerger(ModelCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None: + super().__init__(param_count) + self.sharder = ModelCheckpointSharder(max_shard_size) + self.save_fn = save_fn + + def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: + assert len(dist_metas) == len(tensors) + tensor = merge_param(tensors, dist_metas) + shard = self.sharder.append(key, tensor) + run_if_not_none(self.save_fn, shard) + + def complete(self) -> None: + super().complete() + run_if_not_none(self.save_fn, self.sharder.complete()) + + +class ModelCheckpointRedistor(ModelCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int], + redist_meta: RedistMeta) -> None: + super().__init__(param_count) + self.save_fns = save_fns + self.redist_meta = redist_meta + nprocs = len(save_fns) + self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(nprocs)] + self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for k, rank_meta in redist_meta.rank_meta.items(): + for rank, rank_info in rank_meta.items(): + self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank) + + def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: + if len(dist_metas) == 0: + # already global + tensor = tensors[0] + else: + assert len(dist_metas) == len(tensors) + tensor = merge_param(tensors, dist_metas) + for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])): + for dp_rank, t in enumerate(tensor_list): + for rank in self.rank_map[key][tp_rank][dp_rank]: + shard = self.sharders[rank].append(key, t) + run_if_not_none(self.save_fns[rank], shard) + + def complete(self) -> None: + super().complete() + for rank, save_fn in enumerate(self.save_fns): + run_if_not_none(save_fn, self.sharders[rank].complete()) + + +class OptimizerCheckpointConvertor(CheckpointConvertor): + + def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]], + paired_os: Optional[Dict[int, dict]]) -> None: + super().__init__() + self.param_count = param_count + self.param_to_os = param_to_os + self.paired_os = paired_os + self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict) + self.os_to_param = {v: k for k, v in param_to_os.items()} + + @abstractmethod + def setup(self, param_groups: dict) -> None: + pass + + @abstractmethod + def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: + pass + + def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + for rank, state_dict in shard_dict.items(): + self.setup(state_dict['param_groups']) + for idx, state in state_dict['state'].items(): + self.buffer[idx][rank] = state + converted_indices = set() + for idx, rank_dict in self.buffer.items(): + if len(rank_dict) == self.param_count[self.os_to_param[idx]]: + states = [] + dist_metas = [] + for rank, state in rank_dict.items(): + states.append(state) + if dist_meta_list[rank] is not None: + dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]]) + self.convert_states(idx, states, dist_metas) + converted_indices.add(idx) + for idx in converted_indices: + del self.buffer[idx] + + def complete(self) -> None: + assert len(self.buffer) == 0 + + +class OptimizerCheckpointMerger(OptimizerCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int], + param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None: + super().__init__(param_count, param_to_os, paired_os) + self.max_shard_size = max_shard_size + self.save_fn = save_fn + self.sharder = None + + def setup(self, param_groups: dict) -> None: + if self.sharder is None: + self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups) + + def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: + assert len(dist_metas) == len(states) + new_state = {} + for state_key, state_tensor in states[0].items(): + if self.paired_os[idx][state_key]: + new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas) + else: + new_state[state_key] = state_tensor + shard = self.sharder.append(idx, new_state) + run_if_not_none(self.save_fn, shard) + + def complete(self) -> None: + super().complete() + run_if_not_none(self.save_fn, self.sharder.complete()) + + +class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int], + param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]], + redist_meta: RedistMeta) -> None: + super().__init__(param_count, param_to_os, paired_os) + self.max_shard_size = max_shard_size + self.save_fns = save_fns + self.redist_meta = redist_meta + self.sharders: List[OptimizerCheckpointSharder] = [] + self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for k, rank_meta in redist_meta.rank_meta.items(): + for rank, rank_info in rank_meta.items(): + self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank) + + def setup(self, param_groups: dict) -> None: + if len(self.sharders) == 0: + nprocs = len(self.save_fns) + for _ in range(nprocs): + self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups)) + + def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: + need_merge: bool = True + if len(dist_metas) == 0: + need_merge = False + else: + assert len(dist_metas) == len(states) + new_states = [{} for _ in range(len(self.save_fns))] + for state_key, state_tensor in states[0].items(): + if self.paired_os[idx][state_key]: + if need_merge: + tensor = merge_param([state[state_key] for state in states], dist_metas) + else: + tensor = state_tensor + for tp_rank, tensor_list in enumerate( + unmerge_param(tensor, self.redist_meta.param_meta[self.os_to_param[idx]])): + for dp_rank, t in enumerate(tensor_list): + for rank in self.rank_map[self.os_to_param[idx]][tp_rank][dp_rank]: + new_states[rank][state_key] = t + else: + for new_state in new_states: + new_state[state_key] = state_tensor + for rank, new_state in enumerate(new_states): + shard = self.sharders[rank].append(idx, new_state) + run_if_not_none(self.save_fns[rank], shard) + + def complete(self) -> None: + super().complete() + for rank, save_fn in enumerate(self.save_fns): + run_if_not_none(save_fn, self.sharders[rank].complete()) diff --git a/colossalai/utils/checkpoint_io/distributed.py b/colossalai/utils/checkpoint_io/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..bf720437c41a9114dc805647e415179a9b1243a9 --- /dev/null +++ b/colossalai/utils/checkpoint_io/distributed.py @@ -0,0 +1,127 @@ +import torch +from numpy import prod +from torch import Tensor +from typing import List, Optional, Tuple +from collections import defaultdict +from .meta import ParamDistMeta, ParamRedistMeta + + +def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: + assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) + for dist_meta in dist_metas[1:]: + assert dist_meta.zero_meta == dist_metas[0].zero_meta, 'Expect all params have the same zero meta.' + if not dist_metas[0].used_zero: + # tensors are replicate + return tensors[0] + numel = dist_metas[0].zero_numel + orig_shape = dist_metas[0].zero_orig_shape + tensors = [t[1] for t in sorted(zip(dist_metas, tensors), key=lambda tp: tp[0].dp_rank)] + assert numel == sum(t.numel() for t in tensors), 'Expect numel of all params is equal to zero_numel.' + return torch.cat(tensors).reshape(orig_shape) + + +def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: + assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) + for dist_meta in dist_metas[1:]: + assert dist_meta.tp_meta == dist_metas[0].tp_meta, 'Expect all params have the same tp meta.' + for t in tensors[1:]: + assert t.shape == tensors[0].shape, 'Expect all params have the same shape.' + if not dist_metas[0].used_tp: + # tensors are replicate + return tensors[0] + total_parts = prod(dist_meta.tp_num_parts) + assert dist_meta.tp_world_size == total_parts, \ + f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {dist_meta.tp_world_size}.' + shard_info = sorted(zip(dist_meta.tp_shard_dims, dist_meta.tp_num_parts), key=lambda t: t[0], reverse=True) + for dim, num_parts in shard_info: + buffer = [] + for start in range(0, len(tensors), num_parts): + buffer.append(torch.cat(tensors[start:start + num_parts], dim)) + tensors = buffer + assert len(tensors) == 1 + return tensors[0] + + +def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None: + assert len(dist_metas) > 0 + # check world size + for dist_meta in dist_metas[1:]: + assert dist_meta.dp_world_size == dist_metas[ + 0].dp_world_size, 'Expect all dist meta have the same dp_world_size' + assert dist_meta.tp_world_size == dist_metas[ + 0].tp_world_size, 'Expect all dist meta have the same tp_world_size' + + +def deduplicate_params(tensors: List[Tensor], + dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]: + unique_dist_meta = [] + unique_idx = [] + for i, dist_meta in enumerate(dist_metas): + if dist_meta not in unique_dist_meta: + unique_dist_meta.append(dist_meta) + unique_idx.append(i) + return [tensors[i] for i in unique_idx], [dist_metas[i] for i in unique_idx] + + +def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: + assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) + # validate parallel info + validate_parallel_info(dist_metas) + tensors, dist_metas = deduplicate_params(tensors, dist_metas) + unflattened_tensors = [] + # group zero params by tp rank + tensor_dict = defaultdict(list) + dist_meta_dict = defaultdict(list) + for t, dist_meta in zip(tensors, dist_metas): + tensor_dict[dist_meta.tp_rank].append(t) + dist_meta_dict[dist_meta.tp_rank].append(dist_meta) + assert len(tensor_dict + ) == dist_metas[0].tp_world_size, f'Expect {dist_metas[0].tp_world_size} ranks, got {len(tensor_dict)}' + for tp_rank in tensor_dict.keys(): + unflattened_tensors.append(unflatten_zero_param(tensor_dict[tp_rank], dist_meta_dict[tp_rank])) + return gather_tp_param(unflattened_tensors, [dist_meta_list[0] for dist_meta_list in dist_meta_dict.values()]) + + +def split_tp_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]: + if not redist_meta.used_tp: + assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1, when no tp meta provided.' + return [tensor] + total_parts = prod(redist_meta.tp_num_parts) + assert redist_meta.tp_world_size == total_parts, f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.' + shard_info = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0]) + tensors = [tensor] + for dim, num_parts in shard_info: + buffer = [] + for t in tensors: + assert t.size(dim) % num_parts == 0, \ + f'Expect dim{dim} of tensor({tensor.shape}) is divisible by {num_parts}.' + chunks = [chunk.contiguous() for chunk in t.chunk(num_parts, dim)] + buffer.extend(chunks) + tensors = buffer + assert len(tensors) == redist_meta.tp_world_size + return tensors + + +def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]: + if not redist_meta.used_zero: + return [tensor] * redist_meta.dp_world_size + tensors: List[Optional[Tensor]] = [ + torch.empty(0, dtype=tensor.dtype, device=tensor.device) for _ in range(redist_meta.zero_start_dp_rank) + ] + offsets = redist_meta.zero_offsets + [tensor.numel()] + for i, offset in enumerate(offsets[:-1]): + end = offsets[i + 1] + tensors.append(tensor.view(-1)[offset:end]) + if len(tensors) < redist_meta.dp_world_size: + tensors.extend([ + torch.empty(0, dtype=tensor.dtype, device=tensor.device) + for _ in range(redist_meta.dp_world_size - len(tensors)) + ]) + assert len(tensors) == redist_meta.dp_world_size + return tensors + + +def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]: + tensors = split_tp_param(tensor, redist_meta) + tensors = [flatten_zero_param(t, redist_meta) for t in tensors] + return tensors diff --git a/colossalai/utils/checkpoint_io/io.py b/colossalai/utils/checkpoint_io/io.py new file mode 100644 index 0000000000000000000000000000000000000000..f00212cdf85986ff7b3526a2fd4d56210459dda2 --- /dev/null +++ b/colossalai/utils/checkpoint_io/io.py @@ -0,0 +1,170 @@ +import warnings +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple + +import torch.distributed as dist +from torch.nn import Module +from torch.optim import Optimizer + +from .backend import get_backend +from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger, + OptimizerCheckpointRedistor) +from .meta import ParamDistMeta, RedistMeta +from .utils import build_checkpoints, optimizer_load_state_dict + + +def save(path: str, + model: Module, + optimizer: Optional[Optimizer] = None, + param_to_os: Optional[Dict[str, int]] = None, + dist_meta: Optional[Dict[str, ParamDistMeta]] = None, + max_shard_size_gb: float = 0.0, + overwrite: bool = False, + backend: str = 'disk', + **kwargs: Any) -> None: + io_backend = get_backend(backend) + if dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + if world_size == 1: + # global doesn't need dist_meta + dist_meta = None + else: + assert dist_meta is not None + max_shard_size = int(max_shard_size_gb * 1024**3) + model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer, + param_to_os, dist_meta) + writer = io_backend.get_writer(path, overwrite, rank, world_size) + writer.save_others(kwargs) + for model_checkpoint in model_checkpoints: + writer.save_model(model_checkpoint) + for optimizer_checkpoint in optimizer_checkpoints: + writer.save_optimizer(optimizer_checkpoint) + writer.save_meta(meta_checkpoint) + + +def merge(path: str, + output_path: str, + max_shard_size_gb: float = 0.0, + overwrite: bool = False, + backend: str = 'disk') -> bool: + io_backend = get_backend(backend) + if dist.is_initialized() and dist.get_rank() != 0: + return False + reader = io_backend.get_reader(path) + if len(reader.meta_list) == 1: + # already global + warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.') + return False + dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta() + writer = io_backend.get_writer(output_path, overwrite=overwrite) + writer.save_others(reader.load_others()) + max_shard_size = int(max_shard_size_gb * 1024**3) + _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(), + dist_meta_list) + _convert_shards( + OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os), + reader.load_optimizers(), dist_meta_list) + meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())} + if param_to_os is not None: + meta_checkpoint['param_to_os'] = param_to_os + meta_checkpoint['paired_os'] = paired_os + writer.save_meta(meta_checkpoint) + return True + + +def redist(path: str, + output_path: str, + redist_meta: RedistMeta, + dist_metas: List[Dict[str, ParamDistMeta]], + max_shard_size_gb: float = 0.0, + overwrite: bool = False, + backend: str = 'disk') -> bool: + io_backend = get_backend(backend) + if dist.is_initialized() and dist.get_rank() != 0: + return False + nprocs = len(dist_metas) + reader = io_backend.get_reader(path) + dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta() + do_redist: bool = False + if len(dist_meta_list) == nprocs: + for a, b in zip(dist_metas, dist_meta_list): + if a != b: + do_redist = True + break + else: + do_redist = True + if not do_redist: + warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.') + return False + + writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)] + writers[0].save_others(reader.load_others()) + max_shard_size = int(max_shard_size_gb * 1024**3) + _convert_shards( + ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta), + reader.load_models(), dist_meta_list) + _convert_shards( + OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count, + param_to_os, paired_os, redist_meta), reader.load_optimizers(), dist_meta_list) + for writer, dist_meta in zip(writers, dist_metas): + meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())} + if param_to_os is not None: + meta_checkpoint['param_to_os'] = param_to_os + meta_checkpoint['paired_os'] = paired_os + writer.save_meta(meta_checkpoint) + return True + + +def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None], + dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + for shard_dict in shard_generator: + convertor.append(shard_dict, dist_meta_list) + convertor.complete() + + +def load(path: str, + model: Module, + optimizer: Optional[Optimizer] = None, + redist_meta: Optional[RedistMeta] = None, + dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None, + max_shard_size_gb: float = 0.0, + backend: str = 'disk') -> dict: + is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1 + rank: int = dist.get_rank() if dist.is_initialized() else 0 + is_main_process: bool = rank == 0 + # validate args + if redist_meta is None or dist_metas is None: + assert is_global + io_backend = get_backend(backend) + read_path: str = path + if is_main_process: + # pre-process checkpoints + temp_path = io_backend.get_temp(path) + if is_global: + wrote = merge(path, temp_path, max_shard_size_gb, backend=backend) + else: + wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend) + if wrote: + read_path = temp_path + if not is_global: + bcast_list = [read_path] if is_main_process else [None] + dist.broadcast_object_list(bcast_list) + read_path = bcast_list[0] + reader = io_backend.get_reader(read_path) + # load model + for shard in reader.load_model(rank): + model.load_state_dict(shard, strict=False) + if optimizer is not None: + for shard in reader.load_optimizer(rank): + # optimizer.load_state_dict(shard) + optimizer_load_state_dict(optimizer, shard) + others_dict = reader.load_others() + if not is_global: + dist.barrier() + # clean up temp + if is_main_process: + io_backend.clean_temp() + return others_dict diff --git a/colossalai/utils/checkpoint_io/meta.py b/colossalai/utils/checkpoint_io/meta.py new file mode 100644 index 0000000000000000000000000000000000000000..994f08b4b5e44be753aa838a0c03ea693a78cc27 --- /dev/null +++ b/colossalai/utils/checkpoint_io/meta.py @@ -0,0 +1,81 @@ +from dataclasses import dataclass +from typing import List, Optional, Set, Dict + + +@dataclass +class ParamDistMeta: + # parallel info + dp_rank: int + dp_world_size: int + tp_rank: int + tp_world_size: int + # tp info + tp_shard_dims: Optional[List[int]] = None + tp_num_parts: Optional[List[int]] = None + # zero info + zero_numel: Optional[int] = None + zero_orig_shape: Optional[List[int]] = None + + @property + def used_tp(self) -> bool: + return self.tp_shard_dims is not None and self.tp_num_parts is not None + + @property + def used_zero(self) -> bool: + return self.zero_numel is not None and self.zero_orig_shape is not None + + @property + def parallel_meta(self) -> tuple: + return self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size + + @property + def tp_meta(self) -> tuple: + return self.tp_shard_dims, self.tp_num_parts + + @property + def zero_meta(self) -> tuple: + return self.zero_numel, self.zero_orig_shape + + @staticmethod + def from_dict(d: dict) -> 'ParamDistMeta': + return ParamDistMeta(**d) + + +@dataclass +class ParamRedistMeta: + # parallel info + dp_world_size: int + tp_world_size: int + # tp info + tp_shard_dims: Optional[List[int]] = None + tp_num_parts: Optional[List[int]] = None + # zero info + zero_start_dp_rank: Optional[int] = None + zero_offsets: Optional[List[int]] = None + + @property + def used_tp(self) -> bool: + return self.tp_shard_dims is not None and self.tp_num_parts is not None + + @property + def used_zero(self) -> bool: + return self.zero_start_dp_rank is not None and self.zero_offsets is not None + + +@dataclass +class RankRedistMeta: + dp_rank: int + tp_rank: int + pp_rank: int + + +@dataclass +class PipelineRedistMeta: + params: Set[str] + + +@dataclass +class RedistMeta: + rank_meta: Dict[str, Dict[int, RankRedistMeta]] + pipeline_meta: List[PipelineRedistMeta] + param_meta: Dict[str, ParamRedistMeta] diff --git a/colossalai/utils/checkpoint_io/reader.py b/colossalai/utils/checkpoint_io/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..3158c6481263a7a4be4b15425728d2505b9661e2 --- /dev/null +++ b/colossalai/utils/checkpoint_io/reader.py @@ -0,0 +1,131 @@ +import os +from abc import ABC, abstractmethod +from collections import Counter +from typing import Dict, Generator, List, Optional, Tuple + +import torch + +from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME +from .meta import ParamDistMeta +from .utils import is_duplicated_list + + +class CheckpointReader(ABC): + + def __init__(self, base_name: str) -> None: + super().__init__() + self.base_name = base_name + self.meta_list = [] + + @abstractmethod + def read(self, name: str) -> dict: + pass + + @abstractmethod + def load_meta( + self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]: + pass + + @abstractmethod + def load_model(self, rank: int) -> Generator[dict, None, None]: + pass + + @abstractmethod + def load_models(self) -> Generator[Dict[int, dict], None, None]: + pass + + @abstractmethod + def load_optimizer(self, rank: int) -> Generator[dict, None, None]: + pass + + @abstractmethod + def load_optimizers(self) -> Generator[Dict[int, dict], None, None]: + pass + + @abstractmethod + def load_others(self) -> dict: + pass + + +class DiskCheckpointReader(CheckpointReader): + + def __init__(self, base_name: str) -> None: + super().__init__(base_name) + assert os.path.isdir(base_name), f'"{base_name}" is not a directory' + global_meta = self.read(GLOBAL_META_FILE_NAME) + for meta_file_name in global_meta['meta']: + meta = self.read(meta_file_name) + if meta.get('dist_meta', None) is None: + # only global checkpoint can have empty dist_meta + assert len(global_meta['meta']) == 1 + self.meta_list.append(meta) + + def read(self, name: str) -> dict: + return torch.load(os.path.join(self.base_name, name)) + + def load_meta( + self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]: + meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os', + None), meta.get('paired_os', None)) + for meta in self.meta_list] + dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos) + # reduce param_count + param_count = Counter(p for params in params_list for p in params) + # validate param_to_os + assert is_duplicated_list(param_to_os_list) + assert is_duplicated_list(paired_os_list) + return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0] + + def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]: + meta = self.meta_list[rank] + checkpoint_names = meta.get(shard_type, []) + for name in checkpoint_names: + yield self.read(name) + + def load_model(self, rank: int) -> Generator[dict, None, None]: + return self._load_shard('model', rank) + + def load_models(self) -> Generator[Dict[int, dict], None, None]: + indices = [0] * len(self.meta_list) + while True: + shards = {} + for i, meta in enumerate(self.meta_list): + model_checkpoint_names = meta.get('model', []) + if indices[i] < len(model_checkpoint_names): + shards[i] = self.read(model_checkpoint_names[indices[i]]) + indices[i] += 1 + if len(shards) > 0: + yield shards + else: + break + + def load_optimizer(self, rank: int) -> Generator[dict, None, None]: + param_groups = None + for shard in self._load_shard('optimizer', rank): + if param_groups is None: + param_groups = shard['param_groups'] + else: + shard['param_groups'] = param_groups + yield shard + + def load_optimizers(self) -> Generator[Dict[int, dict], None, None]: + indices = [0] * len(self.meta_list) + param_groups = [] + while True: + shards = {} + for i, meta in enumerate(self.meta_list): + optimizer_checkpoint_names = meta.get('optimizer', []) + if indices[i] < len(optimizer_checkpoint_names): + shards[i] = self.read(optimizer_checkpoint_names[indices[i]]) + if indices[i] == 0: + param_groups.append(shards[i]['param_groups']) + else: + shards[i]['param_groups'] = param_groups[i] + indices[i] += 1 + if len(shards) > 0: + yield shards + else: + break + + def load_others(self) -> dict: + return self.read(OTHER_CKPT_FILE_NAME) diff --git a/colossalai/utils/checkpoint_io/utils.py b/colossalai/utils/checkpoint_io/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..135385f5737947400b8f6ae860f51001ca9742fd --- /dev/null +++ b/colossalai/utils/checkpoint_io/utils.py @@ -0,0 +1,223 @@ +import warnings +from copy import deepcopy +from itertools import chain +from typing import Any, Callable, Dict, List, Optional, Tuple + +from torch import Tensor +from torch.nn import Module +from torch.nn.parameter import Parameter +from torch.optim import Optimizer + +from .meta import ParamDistMeta + + +def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any: + if arg is not None: + return fn(arg) + + +def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]: + # ensure all params in optimizer are in model state dict + params_set = set(id(p) for p in model.parameters()) + for group in optimizer.param_groups: + for p in group['params']: + assert id(p) in params_set + param_mappings = {} + start_index = 0 + + def get_group_mapping(group): + nonlocal start_index + param_mappings.update( + {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings}) + start_index += len(group['params']) + + for g in optimizer.param_groups: + get_group_mapping(g) + return {k: param_mappings[id(p)] for k, p in model.named_parameters()} + + +def compute_optimizer_state_size(state: Dict[str, Any]) -> int: + size = 0 + for v in state.values(): + if isinstance(v, Tensor): + size += v.numel() * v.element_size() + return size + + +class ModelCheckpointSharder: + + def __init__(self, max_shard_size: int) -> None: + self.max_shard_size = max_shard_size + self.buffer: Dict[str, Tensor] = {} + self.buffer_size: int = 0 + + def append(self, key: str, tensor: Tensor) -> Optional[dict]: + retval = None + if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size: + retval = self.buffer + self.buffer = {} + self.buffer_size = 0 + self.buffer[key] = tensor + self.buffer_size += tensor.numel() * tensor.element_size() + return retval + + def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]: + shards = [] + for key, tensor in state_dict.items(): + shard = self.append(key, tensor) + run_if_not_none(shards.append, shard) + return shards + + def complete(self) -> Optional[dict]: + return self.buffer if len(self.buffer) > 0 else None + + +class OptimizerCheckpointSharder: + + def __init__(self, max_shard_size: int, param_groups: dict) -> None: + self.max_shard_size = max_shard_size + self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups} + self.buffer_size: int = 0 + self.returned_first: bool = False + + def append(self, key: int, state: dict) -> Optional[dict]: + retval = None + if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size: + retval = self.buffer + self.buffer = {'state': {}} + self.buffer_size = 0 + self.buffer['state'][key] = state + self.buffer_size += compute_optimizer_state_size(state) + return retval + + def extend(self, state_dict: Dict[str, dict]) -> List[dict]: + shards = [] + for key, state in state_dict['state'].items(): + shard = self.append(key, state) + run_if_not_none(shards.append, shard) + return shards + + def complete(self) -> Optional[dict]: + return self.buffer if len(self.buffer['state']) > 0 else None + + +def shard_checkpoint(max_shard_size: int, + model_state_dict: Dict[str, Tensor], + optimizer_state_dict: Optional[dict] = None, + param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]: + has_optimizer: bool = False + if optimizer_state_dict is not None: + assert param_to_os is not None + os_to_param = {v: k for k, v in param_to_os.items()} + for os_key in optimizer_state_dict['state'].keys(): + assert os_key in os_to_param + assert os_to_param[os_key] in model_state_dict + has_optimizer = True + model_sharder = ModelCheckpointSharder(max_shard_size) + model_shards = model_sharder.extend(model_state_dict) + run_if_not_none(model_shards.append, model_sharder.complete()) + if not has_optimizer: + return model_shards, [] + optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups']) + optimizer_shards = optimizer_sharder.extend(optimizer_state_dict) + run_if_not_none(optimizer_shards.append, optimizer_sharder.complete()) + return model_shards, optimizer_shards + + +def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict: + os_to_param = {v: k for k, v in param_to_os.items()} + paired_os = {} + for idx, state in optimizer_state_dict['state'].items(): + paired_os[idx] = {} + p = model_state_dict[os_to_param[idx]] + for k, v in state.items(): + if isinstance(v, Tensor) and v.shape == p.shape: + paired_os[idx][k] = True + else: + paired_os[idx][k] = False + return paired_os + + +def build_checkpoints(max_size: int, + model: Module, + optimizer: Optional[Optimizer] = None, + param_to_os: Optional[Dict[str, int]] = None, + dist_meta: Optional[Dict[str, ParamDistMeta]] = None, + eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]: + save_global = dist_meta is None + model_state_dict = model.state_dict() + optimizer_state_dict = optimizer.state_dict() if optimizer else None + meta = {'dist_meta': dist_meta} + if optimizer: + param_to_os = param_to_os or get_param_to_os(model, optimizer) + paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os) + meta['param_to_os'] = param_to_os + meta['paired_os'] = paired_os + if not save_global and eliminate_replica: + # filter dp replicated params + model_state_dict = { + k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0 + } + if optimizer: + optimizer_state_dict['state'] = { + param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]] + for k in model_state_dict.keys() + if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0 + } + meta['params'] = list(model_state_dict.keys()) + if len(model_state_dict) == 0: + warnings.warn('model state dict is empty, checkpoint is not saved') + return [], [], meta + model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict, + param_to_os) + return model_checkpoints, optimizer_checkpoints, meta + + +def is_duplicated_list(list_: List[Any]) -> bool: + if len(list_) == 0: + return True + elem = list_[0] + for x in list_[1:]: + if x != elem: + return False + return True + + +def copy_optimizer_state(src_state: dict, dest_state: dict) -> None: + for k, v in src_state.items(): + if k in dest_state: + old_v = dest_state[k] + if isinstance(old_v, Tensor): + old_v.copy_(v) + else: + dest_state[k] = v + + +def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None: + assert optimizer.state_dict()['param_groups'] == state_dict['param_groups'] + state_dict = deepcopy(state_dict) + groups = optimizer.param_groups + saved_groups = state_dict['param_groups'] + idx_to_p: Dict[str, Parameter] = { + old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups + )), chain.from_iterable((g['params'] for g in groups))) + } + missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys())) + unexpected_keys = [] + error_msgs = [] + for idx, state in state_dict['state'].items(): + if idx in idx_to_p: + old_state = optimizer.state[idx_to_p[idx]] + copy_optimizer_state(state, old_state) + else: + unexpected_keys.append(idx) + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__, + "\n\t".join(error_msgs))) diff --git a/colossalai/utils/checkpoint_io/writer.py b/colossalai/utils/checkpoint_io/writer.py new file mode 100644 index 0000000000000000000000000000000000000000..4552accde470d9e0b54b9d658fdcfb6a33e4bc78 --- /dev/null +++ b/colossalai/utils/checkpoint_io/writer.py @@ -0,0 +1,98 @@ +from abc import ABC, abstractmethod +from typing import Optional +from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME +import torch +import os + + +class CheckpointWriter(ABC): + + def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None: + super().__init__() + self.base_name = base_name + self.overwrite = overwrite + self.rank = rank + self.world_size = world_size + self.is_distributed = world_size > 1 + self.is_main_process = rank == 0 + + @abstractmethod + def write(self, name: str, state_dict: dict) -> None: + pass + + @abstractmethod + def save_model(self, model_checkpoint: dict) -> None: + pass + + @abstractmethod + def save_optimizer(self, optimizer_checkpoint: dict) -> None: + pass + + @abstractmethod + def save_meta(self, meta_checkpoint: dict) -> None: + pass + + @abstractmethod + def save_others(self, kwargs: dict) -> None: + pass + + +class DiskCheckpointWriter(CheckpointWriter): + + def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None: + super().__init__(base_name, overwrite, rank, world_size) + if not os.path.exists(base_name): + os.makedirs(base_name) + assert os.path.isdir(base_name), f'"{base_name}" is not a directory' + self.model_checkpoint_names = [] + self.optimizer_checkpoint_names = [] + self.is_meta_saved: bool = False + self._save_global_meta() + + def write(self, name: str, state_dict: dict) -> None: + path = os.path.join(self.base_name, name) + if os.path.exists(path) and not self.overwrite: + raise RuntimeError(f'Save error: Checkpoint "{path}" exists. (overwrite = False)') + torch.save(state_dict, path) + + def _save_global_meta(self) -> None: + if self.is_main_process: + global_meta = {'meta': []} + if self.is_distributed: + for i in range(self.world_size): + global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin')) + else: + global_meta['meta'].append(META_CKPT_FILE_NAME) + self.write(GLOBAL_META_FILE_NAME, global_meta) + + def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str: + checkpoint_name = base_name + if self.is_distributed: + checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin') + if shard_idx is not None: + checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin') + return checkpoint_name + + def save_model(self, model_checkpoint: dict) -> None: + assert not self.is_meta_saved, 'Cannot save model after saving meta' + name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names)) + self.write(name, model_checkpoint) + self.model_checkpoint_names.append(name) + + def save_optimizer(self, optimizer_checkpoint: dict) -> None: + assert not self.is_meta_saved, 'Cannot save optimizer after saving meta' + name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names)) + self.write(name, optimizer_checkpoint) + self.optimizer_checkpoint_names.append(name) + + def save_meta(self, meta_checkpoint: dict) -> None: + if len(self.model_checkpoint_names) > 0: + meta_checkpoint['model'] = self.model_checkpoint_names + if len(self.optimizer_checkpoint_names) > 0: + meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names + self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint) + self.is_meta_saved = True + + def save_others(self, kwargs: dict) -> None: + if self.is_main_process: + self.write(OTHER_CKPT_FILE_NAME, kwargs) diff --git a/colossalai/utils/checkpointing.py b/colossalai/utils/checkpointing.py new file mode 100644 index 0000000000000000000000000000000000000000..d1c6b6370ede4420191eba75e3ab7816f99ed499 --- /dev/null +++ b/colossalai/utils/checkpointing.py @@ -0,0 +1,266 @@ +from collections import OrderedDict +from itertools import chain + +import torch +import torch.distributed as dist +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.constants import IS_TENSOR_PARALLEL +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + +from .common import is_using_pp + +__all__ = ["save_checkpoint", "load_checkpoint"] + + +def broadcast_state_dict(state_dict, parallel_mode): + state_dict = [state_dict.copy() if isinstance(state_dict, dict) else state_dict] + src_rank = gpc.get_ranks_in_group(parallel_mode)[0] + dist.broadcast_object_list(state_dict, src=src_rank, group=gpc.get_cpu_group(parallel_mode)) + return state_dict[0] + + +def partition_tensor_parallel_state_dict(state_dict: OrderedDict, + parallel_mode: ParallelMode, + dims: dict = dict(), + partition_states: dict = dict()): + src_rank = gpc.get_ranks_in_group(parallel_mode)[0] + depth = gpc.get_world_size(parallel_mode) + group = gpc.get_cpu_group(parallel_mode) + is_rank0 = gpc.get_local_rank(parallel_mode) == 0 + partition_info = [None] + if is_rank0: + partition_info_dict = OrderedDict() + for key, param in state_dict.items(): + dim = dims[key] + is_partitioned = partition_states[key] + shape = list(param.shape) + if is_partitioned: + shape[dim] = shape[dim] // depth + partition_info_dict[key] = (is_partitioned, param.dtype, shape, dim) + partition_info[0] = partition_info_dict + dist.broadcast_object_list(partition_info, src_rank, group=group) + partitioned_state = OrderedDict() + for key, (is_partitioned, dtype, shape, dim) in partition_info[0].items(): + if is_partitioned: + output = torch.empty(shape, dtype=dtype) + if is_rank0: + scatter_list = [t.contiguous() for t in state_dict[key].chunk(depth, dim)] + else: + scatter_list = None + dist.scatter(output, scatter_list, src_rank, group=group) + else: + if is_rank0: + output = state_dict[key] + else: + output = torch.empty(shape, dtype=dtype) + dist.broadcast(output, src_rank, group=group) + partitioned_state[key] = output + return partitioned_state + + +def gather_tensor_parallel_state_dict( + state_dict: OrderedDict, + parallel_mode: ParallelMode, + dims: dict = dict(), + partition_states: dict = dict(), + keep_vars: bool = False, +): + dst_rank = gpc.get_ranks_in_group(parallel_mode)[0] + depth = gpc.get_world_size(parallel_mode) + + for key in list(state_dict.keys()): + param = state_dict.pop(key) + param = param if keep_vars else param.detach() + dim = dims.get(key, 0) + do_partition = partition_states.get(key, True) + if do_partition: + temp = param.transpose(0, dim).contiguous() + gather_list = None + if gpc.get_local_rank(parallel_mode) == 0: + shape = list(param.shape) + shape[0], shape[dim] = shape[dim], shape[0] + shape[0] *= depth + param = torch.empty(shape, dtype=param.dtype, device=param.device) + gather_list = list(torch.chunk(param, depth, dim=0)) + dist.gather(temp, gather_list, dst=dst_rank, group=gpc.get_cpu_group(parallel_mode)) + param = torch.transpose(param, 0, dim) + # update params in state_dict only on local rank 0 + if gpc.get_local_rank(parallel_mode) == 0: + state_dict[key] = param + + return state_dict + + +def _send_state_dict(state_dict, dst, parallel_mode): + state_tensor, state_size = dist.distributed_c10d._object_to_tensor(state_dict) + dist.send(state_size, dst, group=gpc.get_cpu_group(parallel_mode)) + dist.send(state_tensor, dst, group=gpc.get_cpu_group(parallel_mode)) + + +def _recv_state_dict(src, parallel_mode): + state_size = torch.tensor([0], dtype=torch.long) + dist.recv(state_size, src, group=gpc.get_cpu_group(parallel_mode)) + state_tensor = torch.empty(state_size.item(), dtype=torch.uint8) + dist.recv(state_tensor, src, group=gpc.get_cpu_group(parallel_mode)) + state_dict = dist.distributed_c10d._tensor_to_object(state_tensor, state_size) + return state_dict + + +def partition_pipeline_parallel_state_dict(model, state_dict): + pipeline_state = OrderedDict() + + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # receive all states from prev stage + if not gpc.is_first_rank(ParallelMode.PIPELINE): + state_dict = _recv_state_dict(gpc.get_prev_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE) + # move states to output + for name, _ in model.named_parameters(recurse=True): + if name in state_dict: + pipeline_state[name] = state_dict.pop(name) + for name, _ in model.named_buffers(recurse=True): + if name in state_dict: + pipeline_state[name] = state_dict.pop(name) + for name, _ in model.named_modules(): + extra_state_key = name + "." + _EXTRA_STATE_KEY_SUFFIX + if extra_state_key in state_dict: + pipeline_state[extra_state_key] = state_dict.pop(extra_state_key) + # send rest states to next stage + if not gpc.is_last_rank(ParallelMode.PIPELINE): + _send_state_dict(state_dict, gpc.get_next_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE) + + return pipeline_state + + +def gather_pipeline_parallel_state_dict(state_dict): + gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))] + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None) + dist.gather_object( + state_dict, + gathered_states, + dst=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[0], + group=gpc.get_cpu_group(ParallelMode.PIPELINE), + ) + + state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states)) + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict()) + + return state_dict + + +def save_checkpoint(file, + epoch: int, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + **kwargs): + """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer, + lr_scheduler etc. into a checkpoint dictionary. + + Args: + file: a file-like object (has to implement write and flush) or a string or os.PathLike object containing a + file name. + epoch (int): Epoch number (indicates how many epochs have you trained this model). + model (:class:`torch.nn.Module`): Model to be saved. + optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved. + lr_scheduler (Union[:class:`torch.optim.lr_scheduler`, :class:`colossalai.nn.lr_scheduler`], optional): + lr_scheduler to be saved, defaults to None. + pickle_module: module used for pickling metadata and objects + pickle_protocol: can be specified to override the default protocol + """ + # ckpt container + checkpoint = {"epoch": epoch} + + model_state = model.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + model_state = gather_pipeline_parallel_state_dict(model_state) + + if gpc.get_global_rank() == 0: + checkpoint["model"] = model_state + + # if optimizer is not None: + # checkpoint['optimizer'] = optimizer.state_dict() + + # if lr_scheduler is not None: + # checkpoint['lr_scheduler'] = lr_scheduler.state_dict() + + torch.save(checkpoint, file, **kwargs) + + +def broadcast_model(model: torch.nn.Module): + src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0] + for p in model.parameters(): + if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0: + group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group( + ParallelMode.TENSOR) + dist.broadcast(p, src_rank, group=group) + + +def load_checkpoint( + file, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + strict: bool = True, +): + """Loads training states from a checkpoint file. + + Args: + file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike + object containing a file name. + model (:class:`torch.nn.Module`): Model to load saved weights and buffers. + optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate. + lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional): + lr_scheduler to recuperate, defaults to None. + strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict` + of the checkpoint match the names of parameters and buffers in model, defaults to True. + + Returns: + int: The saved epoch number. + + Raises: + RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated + """ + state_dict = (torch.load(file, map_location=torch.device("cpu")) + if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None) + + # model states + model_state = state_dict.pop("model") if state_dict is not None else dict() + # pipeline + if is_using_pp(): + model_state = partition_pipeline_parallel_state_dict(model, model_state) + try: + model.load_state_dict(model_state, strict=strict) + broadcast_model(model) + except RuntimeError as e: + error_msgs = str(e) + if error_msgs.startswith("Error(s) in loading state_dict for "): + error_msgs = error_msgs.split("\n\t")[1:] + dst_rank = gpc.get_ranks_in_group(ParallelMode.MODEL)[0] + all_error_msgs = [None for _ in range(gpc.get_world_size(ParallelMode.MODEL))] + dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL)) + if gpc.get_global_rank() == 0: + all_error_msgs = list(chain.from_iterable(all_error_msgs)) + raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format( + model.__class__.__name__, "\n\t".join(all_error_msgs))) + else: + raise e + + # broadcast the rest states + state_dict = broadcast_state_dict(state_dict, ParallelMode.MODEL) + + # # optimizer states + # if optimizer is not None and 'optimizer' in state_dict: + # optimizer.load_state_dict(state_dict['optimizer']) + + # # lr scheduler states + # if lr_scheduler is not None and 'lr_scheduler' in state_dict: + # lr_scheduler.load_state_dict(state_dict['lr_scheduler']) + + # last epoch + last_epoch = state_dict.pop("epoch", -1) + + return last_epoch diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..95b3b8014af1ac2ad1f8c54bec701c45e5c8f900 --- /dev/null +++ b/colossalai/utils/common.py @@ -0,0 +1,472 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import functools +import os +import random +import socket +from collections import defaultdict +from contextlib import contextmanager +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union + +import torch +import torch.distributed as dist +from torch import inf +from torch.nn.parameter import Parameter + +from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.tensor import ColoParameter, ProcessGroup + +from .multi_tensor_apply import multi_tensor_applier + +try: + from colossalai._C import fused_optim +except: + fused_optim = None + + +def print_rank_0(msg: str, logger=None): + """Print messages and save logs(optional). This is executed only if you are the rank-0 gpu. + + Args: + msg (str): A string message to output. + logger (:class:`colossalai.logging.DistributedLogger`, optional): + The logger to record the message, defaults to None. + """ + if gpc.get_global_rank() == 0: + if logger is None: + print(msg, flush=True) + else: + logger.info(msg) + + +def ensure_path_exists(filename: str): + # ensure the path exists + dirpath = os.path.dirname(filename) + if not os.path.exists(dirpath): + Path(dirpath).mkdir(parents=True, exist_ok=True) + + +def sync_model_param(model, parallel_mode): + r"""Make sure data parameters are consistent during Data Parallel Mode. + + Args: + model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. + parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel mode to be checked. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: + for param in model.parameters(): + ranks = gpc.get_ranks_in_group(parallel_mode) + dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode)) + + +def is_dp_rank_0(): + return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA) + + +def is_tp_rank_0(): + return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR) + + +def is_no_pp_or_last_stage(): + return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE) + + +def is_using_ddp(): + return gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1 + + +def is_using_pp(): + return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1 + + +def is_using_sequence(): + return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1 + + +@contextmanager +def conditional_context(context_manager, enable=True): + if enable: + with context_manager: + yield + else: + yield + + +class model_branch_context(object): + + def __enter__(self): + self.env_status = env.save() + + def __exit__(self, *exc_info): + env.load(**self.env_status) + + +def is_model_parallel_parameter(p): + return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL) + + +def is_ddp_ignored(p): + return getattr(p, '_ddp_to_ignore', False) + + +def _calc_l2_norm(grads): + # we should not + global fused_optim + + if fused_optim is None: + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() + + norm = 0.0 + if len(grads) > 0: + dummy_overflow_buf = torch.cuda.IntTensor([0]) + norm, _ = multi_tensor_applier( + fused_optim.multi_tensor_l2norm, + dummy_overflow_buf, + [grads], + False # no per-parameter norm + ) + return norm + + +def _calc_lp(grads, norm_type): + norm = 0.0 + for grad in grads: + grad_norm = torch.norm(grad, norm_type) + norm += grad_norm**norm_type + return norm + + +def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + if torch.is_tensor(norm) and norm.device.type != 'cuda': + norm = norm.to(torch.cuda.current_device()) + return norm + + +def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor: + if isinstance(norm, float): + norm = torch.Tensor([norm]) + if move_to_cuda: + norm = norm.to(torch.cuda.current_device()) + return norm + + +# ======== Gradient Clipping ========= + + +def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float: + if len(params) == 0: + return 0.0 + grads = [p.grad for p in params] + use_cuda_kernel = grads[0].device.type == 'cuda' + if norm_type == inf: + local_lp = max([g.abs().max() for g in grads]) + elif norm_type == 2.0 and use_cuda_kernel: + local_lp = _calc_l2_norm(grads)**norm_type + else: + local_lp = _calc_lp(grads, norm_type) + if isinstance(local_lp, torch.Tensor): + return local_lp.item() + return local_lp + + +def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float: + if len(params) == 0: + return 0.0 + buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list) + for p in params: + if p.is_replicate(): + buckets[None].append(p) + else: + buckets[p.get_process_group().tp_process_group()].append(p) + total_lp = 0.0 + for group, bucket in buckets.items(): + local_lp = _compute_local_lp(bucket, norm_type) + if group is not None: + local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device()) + if norm_type == inf: + dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group) + else: + dist.all_reduce(local_lp_tensor, group=group) + local_lp = local_lp_tensor.item() + if norm_type == inf: + total_lp = max(total_lp, local_lp) + else: + total_lp += local_lp + return total_lp + + +def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float: + if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: + total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device()) + if norm_type == inf: + dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE)) + else: + dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE)) + total_lp = total_lp_tensor.item() + return total_lp + + +def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grad_dtype = None + cpu_grad_params: List[ColoParameter] = [] + cuda_grad_params: List[ColoParameter] = [] + for p in parameters: + if p.grad is None: + continue + assert isinstance(p, ColoParameter) + if grad_dtype is None: + grad_dtype = p.grad.dtype + assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}' + if p.grad.device.type == 'cuda': + cuda_grad_params.append(p) + else: + cpu_grad_params.append(p) + norm_type = float(norm_type) + cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type) + cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type) + if norm_type == inf: + total_lp = max(cpu_lp, cuda_lp) + else: + total_lp = cpu_lp + cuda_lp + return _compute_pp_grad_lp(total_lp, norm_type) + + +def compute_grad_norm(parameters, norm_type: float = 2.0) -> float: + norm_type = float(norm_type) + total_norm = _compute_grad_lp(parameters, norm_type) + if norm_type != inf: + total_norm = total_norm**(1 / norm_type) + return total_norm + + +def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None: + clip_coef = max_norm / (total_norm + 1e-6) + if clip_coef < 1.0: + cuda_grads: List[torch.Tensor] = [] + cpu_grads: List[torch.Tensor] = [] + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + for p in parameters: + if p.grad is None: + continue + if p.grad.device.type == 'cuda': + cuda_grads.append(p.grad.detach()) + else: + cpu_grads.append(p.grad.detach()) + if len(cuda_grads) > 0: + dummy_overflow_buf = torch.cuda.IntTensor([0]) + multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], + clip_coef) + for g in cpu_grads: + g.mul_(clip_coef) + + +def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float: + total_norm = compute_grad_norm(parameters, norm_type) + _clip_grad_norm(parameters, max_norm, total_norm) + return total_norm + + +def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): + """Clips gradient norm of an iterable of parameters whose gradients are in fp32. + + This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and + added functionality to handle model parallel parameters. + + Note: + the gradients are modified in place. + + Args: + parameters (Iterable[:class:`torch.tensor`] or :class:`torch.tensor`): + An iterable of Tensors or a single Tensor that will have gradients normalized. + max_norm (Union[float, int]): Max norm of the gradients. + norm_type (Union[float, int, 'inf']): Type of the used p-norm. Can be ``'inf'`` for infinity norm. + + Returns: + float: Total norm of the parameters. + """ + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + params: List[Parameter] = [] + has_zero_shared_param: bool = False + for param in parameters: + if param.grad is not None: + # Make sure the grads are in fp32 + assert param.grad.dtype == torch.float, \ + f'expected gradient to be dtype torch.float, but got {param.grad.type()}' + if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded: + has_zero_shared_param = True + params.append(param) + + if len(params) == 0: + enable_cuda_kernels = False + else: + enable_cuda_kernels = params[0].grad.device.type == 'cuda' + # Norm parameters. + max_norm = float(max_norm) + norm_type = float(norm_type) + + # Parameters can be on CPU or CUDA + # If parameters are on CPU, disable CUDA kernerls + + # Calculate norm. + if norm_type == inf: + total_norm = max(p.grad.data.abs().max() for p in params) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + # Take max across all model-parallel GPUs. + if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: + dist.all_reduce(total_norm_cuda, + op=dist.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.MODEL), + async_op=False) + if has_zero_shared_param: + dist.all_reduce(total_norm_cuda, + op=dist.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.DATA), + async_op=False) + total_norm = total_norm_cuda[0].item() + else: + tensor_parallel_grads = [] + no_tensor_parallel_grads = [] + zero_sharded_grads = [] + for p in params: + if is_model_parallel_parameter(p): + reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type) + tensor_parallel_grads.append(p.grad.data / reductor) + elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded: + zero_sharded_grads.append(p.grad.data) + else: + no_tensor_parallel_grads.append(p.grad.data) + + if norm_type == 2.0 and enable_cuda_kernels: + tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type + no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type + zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type + else: + tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type) + no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type) + zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type) + # If norm is type of float, then we convert them into torch.Tensor. + tensor_parallel_norm = _get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels) + no_tensor_parallel_norm = _get_tensor_norm(no_tensor_parallel_norm, enable_cuda_kernels) + zero_sharded_norm = _get_tensor_norm(zero_sharded_norm, enable_cuda_kernels) + # If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors + if not enable_cuda_kernels: + tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm) + no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm) + zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm) + + # Sum across all model-parallel GPUs. + if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0: + dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) + # Sum across all zero sharded GPUs + if len(zero_sharded_grads) > 0: + dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA)) + no_tensor_parallel_norm += zero_sharded_norm + total_norm = tensor_parallel_norm + no_tensor_parallel_norm + if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE)) + total_norm = total_norm**(1.0 / norm_type) + if torch.is_tensor(total_norm): + total_norm = total_norm.item() + + # Scale. + clip_coeff = max_norm / (total_norm + 1.0e-6) + if clip_coeff < 1.0: + if enable_cuda_kernels: + grads = [p.grad.detach() for p in params] + dummy_overflow_buf = torch.cuda.IntTensor([0]) + multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff) + else: + for p in params: + p.grad.detach().mul_(clip_coeff) + return total_norm + + +def count_zeros_fp32(parameters): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + total_num_zeros = 0.0 + for param in parameters: + grad_not_none = param.grad is not None + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + if grad_not_none and is_not_tp_duplicate: + grad = param.grad.detach() + num_zeros = grad.numel() - torch.count_nonzero(grad) + total_num_zeros = num_zeros + total_num_zeros + + total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda() + + # Sum across all model-parallel GPUs. + ops = [] + ops.append( + dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True)) + if gpc.is_initialized(ParallelMode.PIPELINE): + ops.append( + dist.all_reduce(total_num_zeros, + op=dist.ReduceOp.SUM, + group=gpc.get_group(ParallelMode.PIPELINE), + async_op=True)) + + for req in ops: + req.wait() + total_num_zeros = total_num_zeros.item() + + return total_num_zeros + + +def copy_tensor_parallel_attributes(src_tensor, dst_tensor): + for attr in TENSOR_PARALLEL_ATTRIBUTES: + if hasattr(src_tensor, attr): + val = getattr(src_tensor, attr) + setattr(dst_tensor, attr, val) + + +def param_is_not_tensor_parallel_duplicate(param): + return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank( + ParallelMode.TENSOR) == 0) + + +@contextmanager +def switch_virtual_pipeline_parallel_rank(rank): + prev_rank = gpc.virtual_pipeline_parallel_rank + try: + gpc.set_virtual_pipeline_parallel_rank(rank) + yield + finally: + gpc.set_virtual_pipeline_parallel_rank(prev_rank) + + +def disposable(func: Callable) -> Callable: + executed = False + + @functools.wraps(func) + def wrapper(*args, **kwargs): + nonlocal executed + if not executed: + executed = True + return func(*args, **kwargs) + + return wrapper diff --git a/colossalai/utils/cuda.py b/colossalai/utils/cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..60f3ccb60883e7af56da5f41e1b18c7e20cc098f --- /dev/null +++ b/colossalai/utils/cuda.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + + +def set_to_cuda(models): + """Send model to gpu. + + :param models: nn.module or a list of module + """ + if isinstance(models, list) and len(models) > 1: + ret = [] + for model in models: + ret.append(model.to(get_current_device())) + return ret + elif isinstance(models, list): + return models[0].to(get_current_device()) + else: + return models.to(get_current_device()) + + +def get_current_device() -> torch.device: + """ + Returns currently selected device (gpu/cpu). + If cuda available, return gpu, otherwise return cpu. + """ + if torch.cuda.is_available(): + return torch.device(f'cuda:{torch.cuda.current_device()}') + else: + return torch.device('cpu') + + +def synchronize(): + """Similar to cuda.synchronize(). + Waits for all kernels in all streams on a CUDA device to complete. + """ + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def empty_cache(): + """Similar to cuda.empty_cache() + Releases all unoccupied cached memory currently held by the caching allocator. + """ + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/colossalai/utils/data_sampler/__init__.py b/colossalai/utils/data_sampler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..12798a94c2d063bb120f805967e748c5a1059a3a --- /dev/null +++ b/colossalai/utils/data_sampler/__init__.py @@ -0,0 +1,4 @@ +from .base_sampler import BaseSampler +from .data_parallel_sampler import DataParallelSampler, get_dataloader + +__all__ = ['BaseSampler', 'DataParallelSampler', 'get_dataloader'] diff --git a/colossalai/utils/data_sampler/base_sampler.py b/colossalai/utils/data_sampler/base_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..89f3bca5b1b51925ef7b32e4a08f1df301776fcb --- /dev/null +++ b/colossalai/utils/data_sampler/base_sampler.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod + + +class BaseSampler(ABC): + + def __init__(self, dataset, batch_size): + self.dataset = dataset + self.batch_size = batch_size + + @abstractmethod + def __len__(self): + pass + + @abstractmethod + def __iter__(self): + pass diff --git a/colossalai/utils/data_sampler/data_parallel_sampler.py b/colossalai/utils/data_sampler/data_parallel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..945dc54b397a3869232827838d4596e95e188059 --- /dev/null +++ b/colossalai/utils/data_sampler/data_parallel_sampler.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +# adpated from torch.utils.data.DistributedSampler + +import math +import random +import numpy as np +from typing import TypeVar, Iterator + +import torch +from torch.utils.data import Sampler, Dataset, DataLoader + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.registry import DATA_SAMPLERS + +T_co = TypeVar('T_co', covariant=True) + + +@DATA_SAMPLERS.register_module +class DataParallelSampler(Sampler): + """A data sampler for distributed data parallelism. + + Args: + dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling. + shuffle (bool, optional): Whether to shuffle data, defaults to False. + seed (int, optional): The random seed used for sampling, defaults to 0. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + """ + + def __init__(self, + dataset: Dataset, + shuffle: bool = False, + seed: int = 0, + drop_last: bool = False) -> None: + self.dataset = dataset + self.num_replicas = gpc.get_world_size(ParallelMode.DATA) + self.rank = gpc.get_local_rank(ParallelMode.DATA) + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + # type: ignore[arg-type] + if self.drop_last and len(self.dataset) % self.num_replicas != 0: + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + # `type:ignore` is required because Dataset cannot provide a default __len__ + # see NOTE in pytorch/torch/utils/data/sampler.py + (len(self.dataset) - self.num_replicas) / \ + self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil( + len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator[T_co]: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + # type: ignore[arg-type] + indices = torch.randperm(len(self.dataset), generator=g).tolist() + + # update for next epoch so that there is no need to call + # set_epoch manually + self.epoch += 1 + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / + len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[:self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + r"""Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch + + +def get_dataloader(dataset, + shuffle=False, + seed=1024, + add_sampler=True, + drop_last=False, + pin_memory=False, + num_workers=0, + **kwargs): + r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not) + + Note: + When pipeline parallel is enabled, shuffle cannot be True as it will result in mismatch between input data + on the 1st stage and label on the last stage. + + Args: + dataset (:class:`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + + if add_sampler and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1: + sampler = DataParallelSampler(dataset, shuffle=shuffle) + else: + sampler = None + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + if sampler is None: + return DataLoader(dataset, + worker_init_fn=seed_worker, + shuffle=shuffle, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) + else: + return DataLoader(dataset, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) diff --git a/colossalai/utils/memory.py b/colossalai/utils/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..434e90edd3b98fb7f69c502d2d5ebf21e127d0bb --- /dev/null +++ b/colossalai/utils/memory.py @@ -0,0 +1,178 @@ +import torch +import gc +import psutil +from collections import namedtuple + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.utils import get_current_device +from colossalai.core import global_context as gpc +from colossalai.context.parallel_mode import ParallelMode +from colossalai.logging import get_dist_logger +from packaging import version + +_GLOBAL_CUDA_MEM_FRACTION = 1.0 +_GLOBAL_CPU_MEM_CAPACITY = -1 + + +def _bytes_to_MB(val, decimal=2): + """A byte-to-Megabyte converter, default using binary notation. + + :param val: X bytes to convert + :return: X' MB + """ + return round(val / (1024 * 1024), decimal) + + +# copy from PatrickStar +def _get_cpu_memory_info(): + ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"]) + try: + # psutil reads the memory info from /proc/memory_info, + # which results in returning the host memory instead of + # that of container. + # Here we try to read the container memory with method in: + # https://stackoverflow.com/a/46213331/5163915 + mems = {} + with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f: + for line in f: + fields = line.split() + mems[fields[0]] = int(fields[1]) * 1024 + total = mems[b"MemTotal:"] + free = mems[b"MemFree:"] + cached = mems[b"Cached:"] + buffers = mems[b"Buffers:"] + used = total - free - cached - buffers + if used < 0: + used = total - free + mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used) + except FileNotFoundError: + mems = psutil.virtual_memory() + mem_info = ps_mem_info( + total=mems.total, + free=mems.free, + cached=mems.cached, + buffers=mems.buffers, + used=mems.used, + ) + return mem_info + + +def report_memory_usage(message, logger=None, report_cpu=False): + """Calculate and print RAM usage (in GB) + + Args: + message (str): A prefix message to add in the log. + logger (:class:`colossalai.logging.DistributedLogger`): The logger used to record memory information. + report_cpu (bool, optional): Whether to report CPU memory. + + Raises: + EnvironmentError: Raise error if no distributed environment has been initialized. + """ + if not gpc.is_initialized(ParallelMode.GLOBAL): + raise EnvironmentError("No distributed environment is initialized") + + gpu_allocated = _bytes_to_MB(torch.cuda.memory_allocated()) + gpu_max_allocated = _bytes_to_MB(torch.cuda.max_memory_allocated()) + gpu_cached = _bytes_to_MB(torch.cuda.memory_reserved()) + gpu_max_cached = _bytes_to_MB(torch.cuda.max_memory_reserved()) + + full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \ + + f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB" + + if report_cpu: + # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports + gc.collect() + vm_stats = psutil.virtual_memory() + vm_used = _bytes_to_MB(vm_stats.total - vm_stats.available) + full_log += f", CPU Virtual Memory: used = {vm_used} MB, percent = {vm_stats.percent}%" + + if logger is None: + logger = get_dist_logger() + logger.info(full_log) + + # get the peak memory to report correct data, so reset the counter for the next call + if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ + torch.cuda.reset_peak_memory_stats() + + +def colo_device_memory_capacity(device: torch.device) -> int: + """ + Get the capacity of the memory of the device + + Args: + device (torch.device): a device + + Returns: + int: size in byte + """ + assert isinstance(device, torch.device) + if device.type == 'cpu': + # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory. + return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node + if device.type == 'cuda': + return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION + + +def colo_device_memory_used(device: torch.device) -> int: + """ + Get the device memory on device belonging to the current process. + + Args: + device (torch.device): a device + + Returns: + int: memory size in bytes + """ + if device.type == 'cpu': + mem_info = _get_cpu_memory_info() + # In the context of 1-CPU-N-GPU, the memory usage of the current process is 1/N CPU memory used. + # Each process consumes the same amount of memory. + ret = mem_info.used / gpc.num_processes_on_current_node + return ret + elif device.type == 'cuda': + ret: int = torch.cuda.memory_allocated(device) + # get the peak memory to report correct data, so reset the counter for the next call + if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ + torch.cuda.reset_peak_memory_stats(device) + return ret + + +def colo_set_process_memory_fraction(ratio: float) -> None: + """colo_set_process_memory_fraction + + set how much cuda memory used on the gpu belonging to the current process. + + Args: + ratio (float): a ratio between 0. ~ 1. + """ + if version.parse(torch.__version__) < version.parse('1.8'): + logger = get_dist_logger('colo_set_process_memory_fraction') + logger.warning('colo_set_process_memory_fraction failed because torch version is less than 1.8') + return + global _GLOBAL_CUDA_MEM_FRACTION + _GLOBAL_CUDA_MEM_FRACTION = ratio + torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device()) + + +def colo_set_cpu_memory_capacity(size: int) -> None: + global _GLOBAL_CPU_MEM_CAPACITY + mem_info = _get_cpu_memory_info() + total_size = mem_info.total + if size <= total_size: + _GLOBAL_CPU_MEM_CAPACITY = size + else: + _GLOBAL_CPU_MEM_CAPACITY = total_size + + +def colo_get_cpu_memory_capacity() -> int: + """ + Get the cpu memory capacity. We may not use all of it. + Returns: + int: _description_ + """ + global _GLOBAL_CPU_MEM_CAPACITY + if _GLOBAL_CPU_MEM_CAPACITY == -1: + mem_info = _get_cpu_memory_info() + return mem_info.total + else: + return _GLOBAL_CPU_MEM_CAPACITY diff --git a/colossalai/utils/model/__init__.py b/colossalai/utils/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colossalai/utils/model/experimental.py b/colossalai/utils/model/experimental.py new file mode 100644 index 0000000000000000000000000000000000000000..bf3e3d05b99cdc6b8e08922a426d5f3e1096fd07 --- /dev/null +++ b/colossalai/utils/model/experimental.py @@ -0,0 +1,604 @@ +from types import MethodType +from typing import Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch import Tensor +from torch.utils._pytree import tree_map + +from colossalai._analyzer._subclasses import MetaTensor +from colossalai.tensor.d_tensor.d_tensor import DTensor +from colossalai.tensor.d_tensor.layout import Layout + +# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html +_NORMAL_FACTORY = [ + "arange", + "full", + "empty", + "linspace", + "logspace", + "ones", + "rand", + "randn", + "randint", + "randperm", + "zeros", + "tensor", +] + +# factory function that does not support meta tensor backend +_NO_META_FACTORY = [ + "eye", +] + +_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split'] + +# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset) +# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block. +# These ops cannot be unwrapped using .data +_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__', '__set__'] + +_LEGACY_TENSOR_CONSTRUCTOR = { + 'FloatTensor': torch.float, + 'DoubleTensor': torch.double, + 'HalfTensor': torch.half, + 'BFloat16Tensor': torch.bfloat16, + 'ByteTensor': torch.uint8, + 'CharTensor': torch.int8, + 'ShortTensor': torch.short, + 'IntTensor': torch.int, + 'LongTensor': torch.long, + 'BoolTensor': torch.bool, +} + +_EMPTY_DATA = torch.empty(0) + + +class _MyTensor(Tensor): + """This class is only for correctness verification. + """ + _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None + + def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor': + cls._pre_op_fn() + if concrete_data is not None: + # uniform api as LazyTensor + data = concrete_data + else: + data = func(*args, **kwargs) + return Tensor._make_subclass(cls, data, require_grad=data.requires_grad) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + cls._pre_op_fn() + return super().__torch_function__(func, types, args, kwargs) + + +def _data_tolist(tensor: torch.Tensor) -> list: + """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor. + """ + return tensor.data.tolist() + + +def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor: + """Convert a lazy tensor's class to target's class, with target's data. + + The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models. + If we create a new tensor and update the module by ``setattr(module, name, param)``, the shared parameters will not be updated. And we have to track all shared parameters and update them manually. + + Args: + tensor (LazyTensor): the LazyTensor to be converted + target (torch.Tensor): target tensor + + Returns: + torch.Tensor: the converted tensor + """ + cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor + tensor.__class__ = cls_to_become + tensor.data = target + tensor.requires_grad = target.requires_grad + # subclass of torch.Tensor does not have tolist() method + # overwrite this method after materialization or distribution + tensor.tolist = MethodType(_data_tolist, tensor) + return tensor + + +class LazyTensor(torch.Tensor): + """A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf). + + Usage: + 1. Use ``LazyTensor`` instead of ``torch.Tensor``. + >>> x = LazyTensor(torch.zeros, 2, 3) + >>> x += 1 + >>> y = x * x + >>> y = y.cuda().half() + >>> y[0, 0] = 0 + >>> y = y.materialize() # materialize the tensor + >>> print(y) + tensor([[0., 1., 1.], + [1., 1., 1.]], device='cuda:0', dtype=torch.float16) + + Warnings: + 1. Cases that ``LazyTensor`` can't deal with. + >>> x = LazyTensor(torch.ones, 2, 3) + >>> x[0, 0] = -x[0, 0] # this will cause infinite recursion + >>> y = x.clone() + >>> x.add_(1) # modifying origin tensor after cloning leads to wrong materialization + >>> z = x.tolist() + >>> x.zeros_() # modifying origin tensor after cloning tolist is not allowed + >>> nn.utils.weight_norm(self.conv, name="weight", dim=2) # applying weight norm on a lazy tensor is not allowed + + + 2. Cases that ``LazyTensor`` becomes eager (early materialization). + >>> b = a[:, 2:] # get a slice of a lazy tensor triggers early materialization + >>> chunks = a.split(3) # this also triggers early materialization + >>> x.data = torch.rand(2, 3) # directly setting data of a lazy tensor triggers early materialization + + """ + + _repr = True + _meta_data: Optional[MetaTensor] = None # shape, dtype, device + _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None + + @staticmethod + def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs): + if concrete_data is not None: + # some ops don't support meta backend and should have concrete data + elem = concrete_data + else: + if meta_data is None: + device = kwargs.get('device', 'cpu') + elem = func(*args, **{**kwargs, 'device': 'meta'}) + meta_data = MetaTensor(elem, device=device) + elem = meta_data._tensor + # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here + r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad) + r._meta_data = meta_data + return r + + def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs): + self._factory_method = (func, args, kwargs) # (func, args, kwargs) + self._op_buffer = [] # (func, args, kwargs, replace) + self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data + + def materialize(self) -> torch.Tensor: + """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace). + + Returns: + torch.Tensor: The materialized tensor (self). + """ + target = self._materialize_data() + self.clean() + return _convert_cls(self, target) + + def distribute(self, layout: Layout) -> torch.Tensor: + """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout. + + Args: + layout (Layout): Distribution layout. + + Returns: + torch.Tensor: The distributed tensor (self). + """ + target = self._materialize_data() + self.clean() + local_tensor = DTensor(target, layout).local_tensor + return _convert_cls(self, local_tensor) + + def clean(self) -> None: + """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized. + """ + self._factory_method = None + self._op_buffer = None + self._materialized_data = None + self._meta_data = None + + @staticmethod + def _replace_with_materialized(x): + if isinstance(x, LazyTensor): + return x._materialize_data() + return x + + def _materialize_data(self) -> torch.Tensor: + # self._materialized_data should be generated after the first call of this function + if self._materialized_data is None: + # apply factory method + func, args, kwargs = self._factory_method + + # apply cached sequence + self._pre_op_fn() + + try: + init_val = func(*tree_map(self._replace_with_materialized, args), + **tree_map(self._replace_with_materialized, kwargs)) + except TypeError as e: + print(f'init fn: {func.__name__}') + raise e + + self._materialized_data = self._rerun_ops(init_val) + return self._materialized_data + + def _rerun_ops(self, target=None) -> torch.Tensor: + """Do lazy execution by rerunning all (stored) related operations. + + Args: + target (torc.Tensor, optional): Intial value of the target tensor (self). Defaults to None. + """ + + def replace(x): + if x is self: + return target + elif isinstance(x, LazyTensor): + return x._materialize_data() + return x + + packed = None + + for (func, args, kwargs) in self._op_buffer: + if func == torch.Tensor.requires_grad_: + packed = func, args, kwargs # requires grad should be set at last + else: + self._pre_op_fn() + o = func(*tree_map(replace, args), **tree_map(replace, kwargs)) + target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value + + # super-dainiu: set requires_grad after all inplace-ops are done + if packed is not None: + func, args, kwargs = packed + func(*tree_map(replace, args), **tree_map(replace, kwargs)) + + return target + + # cache everything with __torch_function__ + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func.__name__ in _EARLY_MATERIALIZED_OPS: + # These OPs cannot be lazy and related tensors should be early materialized + tree_map(cls._replace_with_materialized, args) + tree_map(cls._replace_with_materialized, kwargs) + is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__')) + or func.__name__ in ('__setitem__', '__set__')) + + is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS + + if isinstance(func, torch._C.ScriptMethod): + # FIXME(ver217): torch script functions are not verified + + target = None + + def unwrap(x): + if isinstance(x, LazyTensor): + return x._meta_data + return x + + target: LazyTensor = args[0].clone() + target._op_buffer.append((func, args, kwargs)) + target._meta_data = getattr(target._meta_data, func.name)(*tree_map(unwrap, args[1:]), + **tree_map(unwrap, kwargs)) + return target + else: + + meta_to_lazy = {} + + def unwrap(x): + if isinstance(x, LazyTensor): + if x._materialized_data is not None: + # for early materialized tensor, use its materialized data directly + return x._materialized_data if is_change_meta_op else x._materialized_data.data + t = x if is_inplace else x.clone() + t._op_buffer.append((func, args, kwargs)) + meta = x._meta_data if is_change_meta_op else x._meta_data.data + meta_to_lazy[meta] = t + return meta + return x + + def wrap(y, i=None): + if isinstance(y, MetaTensor): + if y in meta_to_lazy: + # inplace op, just return origin lazy tensor + return meta_to_lazy[y] + else: + # out of place op, create new lazy tensor + fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i] + lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs) + return lazy_y + elif type(y) is Tensor: + # for early materialized tensor + return LazyTensor(lambda: None, concrete_data=y) + return y + + cls._pre_op_fn() + o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) + if isinstance(o, (tuple, list)): + return type(o)(wrap(y, i=i) for i, y in enumerate(o)) + return wrap(o) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + pass # skip + + def clone(self) -> "LazyTensor": + + def factory_fn(): + # if self is materialized, return self + new_tensor = self.materialize() if type(self) is LazyTensor else self + return new_tensor.clone() + + target = LazyTensor(factory_fn, meta_data=self._meta_data) + + return target + + def detach(self) -> Tensor: + return self + + def __deepcopy__(self, memo): + if not self.is_leaf: + raise RuntimeError("Only Tensors created explicitly by the user " + "(graph leaves) support the deepcopy protocol at the moment") + if id(self) in memo: + return memo[id(self)] + + def factory_fn(): + # if self is materialized, return self + new_tensor = self.materialize() if type(self) is LazyTensor else self + copied = new_tensor.detach().clone() + if new_tensor.requires_grad: + copied.requires_grad_() + return copied + + target = LazyTensor(factory_fn, meta_data=self._meta_data) + + memo[id(self)] = target + return target + + @property + def data(self): + return self + + @data.setter + def data(self, other: 'LazyTensor'): + """This is sightly different from oringinal `data` setter. + + E.g.: + >>> a = torch.randn(3, 3) # a is a Tensor + >>> b = torch.rand(2, 2) + >>> a.data = b + >>> b.add_(1) # this will affect a + >>> x = torch.randn(3, 3) # x is a LazyTensor + >>> y = torch.rand(2, 2) # y is a LazyTensor + >>> x.data = y + >>> y.add_(1) # this will not affect x + + """ + if other is self: + return + + self._op_buffer.append(other._factory_method) + + def replace(x): + if x is other: + return self + return x + + for func, args, kwargs in other._op_buffer: + self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs))) + + def tolist(self) -> list: + # Though self.__class__ is modified to torch.Tensor, in C++ side, it is still a subclass of torch.Tensor + # And subclass of torch.Tensor does not have tolist() method + t = self._materialize_data() + return t.tolist() + + def __hash__(self): + return id(self) + + +class LazyInitContext: + """Context manager for lazy initialization. Enables initializing the model without allocating real memory. + + Usage: + 1. The model is initialized, but no real memory is allocated. + >>> ctx = LazyInitContext() + >>> with ctx: + >>> model = MyModel().cuda() + + 2. The model is initialized with ``MetaTensor`` as weights, but still no real memory is allocated. + >>> with ctx.traceable(model): + >>> gm = symbolic_trace(model, meta_args=meta_args) + >>> # Solve the execution strategy and apply the strategy to the model + >>> strategy = StrategyAndSpec() + + 3. The model is initialized with ``torch.Tensor`` as weights, and real memory is allocated. (single device) + >>> model = ctx.materialize(model) + + 3. The model is initialized with sharded ``torch.Tensor`` as weights, and real memory is allocated. (distributed scenario) + >>> model = apply_strategy_to_all_params(model, strategy) + >>> model = ctx.distribute(model) + + Warnings: + This API is still experimental and further modifications can be made to it. + For example: + 1. Quantization strategies can be applied before allocating real memory. + 2. Lazy initialization seems slower than normal initialization. + """ + _replaced: bool = False + + def __init__(self, tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor): + self.overrides = {} + self.tensor_cls = tensor_cls + + def __enter__(self): + if LazyInitContext._replaced: + raise RuntimeError(f'LazyInitContext is not reentrant') + LazyInitContext._replaced = True + + def wrap_factory_method(target): + # factory functions (eg. torch.empty()) + def wrapper(*args, **kwargs): + return self.tensor_cls(target, *args, **kwargs) + + return wrapper, target + + def wrap_factory_like_method(orig_target, target): + # factory_like functions (eg. torch.empty_like()) + def wrapper(*args, **kwargs): + orig_t = args[0] + return self.tensor_cls(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs) + + return wrapper, target + + def wrap_legacy_constructor(target, dtype): + # legacy constructor (e.g. torch.LongTensor()) + def wrapper(*args, **kwargs): + if len(args) == 1 and isinstance(args[0], torch.Tensor): + # (Tensor other) + return args[0] + elif len(args) == 1: + # (object data, *, torch.device device) + kwargs = {**kwargs, 'dtype': dtype} + replaced, orig = self.overrides['tensor'] + return replaced(*args, **kwargs) + elif _is_int_tuple(args): + # (tuple of ints size, *, torch.device device) + kwargs = {**kwargs, 'dtype': dtype} + replaced, orig = self.overrides['empty'] + return replaced(*args, **kwargs) + else: + raise TypeError( + f'new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)' + ) + + return wrapper, target + + def wrap_no_meta_factory(target): + # factory functions which don't support meta tensor backend + def wrapper(*args, **kwargs): + tensor = target(*args, **kwargs) + return self.tensor_cls(lambda: None, concrete_data=tensor) + + return wrapper, target + + self.overrides = { + target: wrap_factory_method(getattr(torch, target)) + for target in _NORMAL_FACTORY + if callable(getattr(torch, target, None)) + } + + self.overrides.update({ + target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like')) + for target in _NORMAL_FACTORY + if callable(getattr(torch, target + '_like', None)) + }) + + self.overrides.update({ + target: wrap_legacy_constructor(getattr(torch, target), dtype) + for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items() + if callable(getattr(torch, target, None)) + }) + + self.overrides.update({ + target: wrap_no_meta_factory(getattr(torch, target)) + for target in _NO_META_FACTORY + if callable(getattr(torch, target, None)) + }) + + for name, (wrapper, orig) in self.overrides.items(): + setattr(torch, name, wrapper) + + def __exit__(self, exc_type, exc_val, exc_tb): + LazyInitContext._replaced = False + for name, (wrapper, orig) in self.overrides.items(): + setattr(torch, name, orig) + + @staticmethod + def materialize(module: nn.Module, verbose: bool = False) -> nn.Module: + """Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. + + Args: + module (nn.Module): Target ``nn.Module`` + verbose (bool): Whether to print lazy initialization rate. Defaults to False. + """ + + def apply_fn(name: str, p: LazyTensor): + p.materialize() + + return _apply_to_lazy_module(module, apply_fn, verbose) + + @staticmethod + def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module: + """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. + + Args: + module (nn.Module): Target ``nn.Module`` + layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout. + verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False. + """ + + def apply_fn(name: str, p: LazyTensor): + p.distribute(layout_dict[name]) + + return _apply_to_lazy_module(module, apply_fn, verbose) + + +def _apply_to_lazy_module(module: nn.Module, + apply_fn: Callable[[str, torch.Tensor], None], + verbose: bool = False) -> nn.Module: + if verbose: + # verbose info + param_cnt = 0 + param_lazy_cnt = 0 + buf_cnt = 0 + buf_lazy_cnt = 0 + total_numel = 0 + non_lazy_numel = 0 + + for name, p in module.named_parameters(): + if verbose: + param_cnt += 1 + total_numel += p.numel() + if getattr(p, '_materialized_data', False) is None: + # if no _materialized_data attr, the tensor is not lazy + param_lazy_cnt += 1 + else: + non_lazy_numel += p.numel() + if isinstance(p, LazyTensor): + apply_fn(name, p) + + for name, buf in module.named_buffers(): + if verbose: + buf_cnt += 1 + total_numel += buf.numel() + if getattr(buf, "_materialized_data", False) is None: + # if no _materialized_data attr, the tensor is not lazy + buf_lazy_cnt += 1 + else: + non_lazy_numel += buf.numel() + if isinstance(buf, LazyTensor): + apply_fn(name, buf) + + if verbose: + non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0 + _print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}') + _print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}') + _print_rank_0( + f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%') + + return module + + +def _print_rank_0(*args, **kwargs): + if not dist.is_initialized() or dist.get_rank() == 0: + print(*args, **kwargs) + + +def _is_int_tuple(args) -> bool: + if not isinstance(args, tuple): + return False + for x in args: + if not isinstance(x, int): + return False + return True diff --git a/colossalai/utils/model/lazy_init_context.py b/colossalai/utils/model/lazy_init_context.py new file mode 100644 index 0000000000000000000000000000000000000000..cf05f966089d16884166469cff299e6991192097 --- /dev/null +++ b/colossalai/utils/model/lazy_init_context.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python +# coding: utf-8 + +import inspect +import types +from typing import Callable, List + +import torch +import torch.nn as nn + +from colossalai.tensor import ColoParameter, ColoTensor +from colossalai.utils.model.utils import substitute_init_recursively + + +class LazyInitContext(): + """ + A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor + initialization functions for lazy initialization + + Note: + This API is only experimental and subject to future changes. + + Usage: + with LazyInitContext() as ctx: + model = nn.Linear(10, 10) + model.weight.zero_() + + # make sure the weight is a meta tensor + assert model.weight.is_meta + + # initialize weights + ctx.lazy_init_parameters(model) + + # make sure the weight is not a meta tensor + # and initialized correctly + assert not model.weight.is_meta and torch.all(model.weight == 0) + + Args: + to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This + argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet. + extra_torch_tensor_func (List[str]): extra torch tensor functions related + to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default. + """ + + tensor_set_value_func = ['zero_', 'fill_'] + + def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None): + # TODO: hijack the torch constructor functions as well + self._to_meta = to_meta + self._intercepted_nn_init_func_cache = {} + self._nn_init_methods = self._get_nn_init_methods() + self._torch_mod_cls = torch.nn.modules.module.Module + + if extra_torch_tensor_func: + # use tuple to remove duplicates + self._torch_tensor_funcs = tuple(self.tensor_set_value_func + extra_torch_tensor_func) + else: + self._torch_tensor_funcs = self.tensor_set_value_func + + @property + def to_meta(self): + return self._to_meta + + def _cache_init_func(self, func): + """ + This method wraps the ``torch.nn.init`` method and torch tensor value-setting functions + so that the function call is cached instead of being executed. + """ + + def wrapped_init_func(tensor, *args, **kwargs): + if tensor not in self._intercepted_nn_init_func_cache: + self._intercepted_nn_init_func_cache[tensor] = [] + self._intercepted_nn_init_func_cache[tensor].append((func, args, kwargs)) + + return wrapped_init_func + + def _get_nn_init_methods(self): + """ + This method looks for all available functions in the ``torch.nn.init`` + module. + """ + nn_init_method_names = dir(torch.nn.init) + nn_init_methods = [] + + # look for all methods in ``torch.nn.init`` module + for name in nn_init_method_names: + nn_init_methods.append((name, getattr(torch.nn.init, name))) + + def _is_init_method(item): + name, func = item + + if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')): + return False + else: + return True + + # remove methods which are not init functions + nn_init_methods = list(filter(_is_init_method, nn_init_methods)) + return nn_init_methods + + def _wrap_module_init(self, func): + """ + This method wraps the calls to the `__init__` of ``torch.nn.Module`` and replaces + the argument device with value 'meta' so that all modules are created as meta tensors. + """ + has_device = 'device' in inspect.signature(func).parameters + + def layer_lazy_init(module, *args, **kwargs): + # if this module contains device argument + # we set it to meta to initialize as meta backend + if has_device: + kwargs['device'] = 'meta' + func(module, *args, **kwargs) + + # if device is not found, we intialize it and convert to meta + if not has_device: + module.to('meta') + + return layer_lazy_init + + def _get_tmp_origin_func_ref(self, name): + """ + Generate a function name for consistency during caching and retrieving. + """ + return f'_orig_{name}' + + def _patch_nn_init_funcs(self): + # patch nn.init functions + for name, func in self._nn_init_methods: + setattr(torch.nn.init, name, self._cache_init_func(func)) + + def _unpatch_nn_init_funcs(self): + # unpatch nn.init functions + for name, func in self._nn_init_methods: + setattr(torch.nn.init, name, func) + + def _patch_submodule_init(self): + # patch classes __init__ methods + def _activate_wrap_init(cls): + cls.__orig_init__ = cls.__init__ + cls.__init__ = self._wrap_module_init(cls.__init__) + + substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init, set()) + + def _unpatch_submodule_init(self): + + def _recover_orig_init(cls): + cls.__init__ = cls.__orig_init__ + + substitute_init_recursively(self._torch_mod_cls, _recover_orig_init, set()) + + def _patch_torch_tensor_funcs(self): + # patch tensor value-setting functions + for func_name in self._torch_tensor_funcs: + origin_func_name = self._get_tmp_origin_func_ref(func_name) + origin_func = getattr(torch.Tensor, func_name) + setattr(torch.Tensor, origin_func_name, origin_func) + setattr(torch.Tensor, func_name, self._cache_init_func(origin_func)) + + def _unpatch_torch_tensor_funcs(self): + for func_name in self._torch_tensor_funcs: + origin_func_name = self._get_tmp_origin_func_ref(func_name) + origin_func = getattr(torch.Tensor, origin_func_name) + setattr(torch.Tensor, func_name, origin_func) + + def __enter__(self): + self._patch_torch_tensor_funcs() + self._patch_nn_init_funcs() + + if self._to_meta: + self._patch_submodule_init() + return self + + def __exit__(self, *args, **kwargs): + if self._to_meta: + self._unpatch_submodule_init() + self._unpatch_nn_init_funcs() + self._unpatch_torch_tensor_funcs() + + def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'): + """ + Initialize the weights of the meta-tensor model. + + Args: + model (`torch.nn.Module`): the model instantiated under the context. + device (str): the device on which weights are initialized + + """ + + def _init_recursively(module: nn.Module): + # recursively initialize the module + for mod in module.children(): + _init_recursively(mod) + + # initialize and shard tensors directly attached to the current module + for name, param in module.named_parameters(recurse=False): + _init_and_shard(module, name, param) + + for name, buf in module.named_buffers(recurse=False): + _init_and_shard(module, name, buf) + + @torch.no_grad() + def _init_and_shard(module, name, tensor): + # check whether the tensor is a buffer or parameter + is_param = isinstance(tensor, nn.parameter.Parameter) + + # get sharding spec + dist_spec = getattr(tensor, 'dist_spec', None) + pg = getattr(tensor, 'pg', None) + comp_spec = getattr(tensor, 'comp_spec', None) + + # convert the tensor from meta to materialized one + if tensor.is_meta: + materialized_tensor = torch.empty_like(tensor, device=device) + # if this tensor is a meta tensor, it must have an init function + assert tensor in self._intercepted_nn_init_func_cache + else: + materialized_tensor = tensor + + # apply init function + if tensor in self._intercepted_nn_init_func_cache: + init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1] + init_func(materialized_tensor, *args, **kwargs) + + # convert it to ColoTensor or ColoParameter + if is_param: + tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad) + else: + tensor = ColoTensor.from_torch_tensor(materialized_tensor) + + # override the original tensor + with torch.no_grad(): + setattr(module, name, tensor) + + # apply sharding + if dist_spec: + tensor.process_group = pg + tensor.set_tensor_spec(dist_spec, comp_spec) + + _init_recursively(model) + + return model diff --git a/colossalai/utils/model/utils.py b/colossalai/utils/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f49607376439f6e48797a45c43afceb3fdd27224 --- /dev/null +++ b/colossalai/utils/model/utils.py @@ -0,0 +1,115 @@ +# This code has been adapted from the DeepSpeed library. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import functools +from typing import Optional + +import torch + + +def substitute_init_recursively(cls, func, visited: set): + for subcls in cls.__subclasses__(): + substitute_init_recursively(subcls, func, visited) + if subcls not in visited: + func(subcls) + visited.add(subcls) + + +def call_to_str(base, *args, **kwargs): + """Construct a string representation of a call. + + Args: + base (str): name of the call + args (tuple, optional): args to ``base`` + kwargs (dict, optional): kwargs supplied to ``base`` + + Returns: + str: A string representation of base(*args, **kwargs) + """ + name = f'{base}(' + if args: + name += ', '.join(repr(arg) for arg in args) + if kwargs: + name += ', ' + if kwargs: + name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items()) + name += ')' + return name + + +class InsertPostInitMethodToModuleSubClasses(object): + + def __init__(self, default_dtype: Optional[torch.dtype] = None): + self._old_default_dtype = None + self._default_dtype = default_dtype + + def __enter__(self): + r""" + Enter the context scope. + """ + if self._default_dtype is not None: + self._old_default_dtype = torch.get_default_dtype() + torch.set_default_dtype(self._default_dtype) + + def preprocess_after(f): + + @functools.wraps(f) + def wrapper(module: torch.nn.Module, *args, **kwargs): + f(module, *args, **kwargs) + self._post_init_method(module, *args, **kwargs) + + return wrapper + + def _enable_class(cls): + cls._old_init = cls.__init__ + cls.__init__ = preprocess_after(cls.__init__) + + # The function is called during init subclass. + def _init_subclass(cls, **kwargs): + cls.__init__ = preprocess_after(cls.__init__) + + # Replace .__init__() for all existing subclasses of torch.nn.Module + # Excution self._post_init_method after the default init function. + substitute_init_recursively(torch.nn.modules.module.Module, _enable_class, set()) + + # holding on to the current __init__subclass__ for exit + torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__) + # Replace .__init__() for future subclasses of torch.nn.Module + torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) + + self._pre_context_exec() + return self + + def __exit__(self, exc_type, exc_value, traceback): + + if self._default_dtype is not None: + torch.set_default_dtype(self._old_default_dtype) + + def _disable_class(cls): + if not hasattr(cls, '_old_init'): + raise AttributeError( + f"_old_init is not found in the {cls.__name__}, please make sure that you have imported {cls.__name__} before entering the context." + ) + cls.__init__ = cls._old_init + + # Replace .__init__() for all existing subclasses of torch.nn.Module + substitute_init_recursively(torch.nn.modules.module.Module, _disable_class, set()) + + # Replace .__init__() for future subclasses of torch.nn.Module + torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass) + + self._post_context_exec() + # Now that we cleaned up the metaclass injection, raise the exception. + if exc_type is not None: + return False + + # To be implemented by inheriting classes + def _post_init_method(self, module, *args, **kwargs): + pass + + def _pre_context_exec(self): + pass + + def _post_context_exec(self): + pass diff --git a/colossalai/utils/moe.py b/colossalai/utils/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..86d04c11958b0f03389cf5ae6e12da4e487fca98 --- /dev/null +++ b/colossalai/utils/moe.py @@ -0,0 +1,52 @@ +import torch.nn as nn +import torch.distributed as dist +from colossalai.core import global_context as gpc +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.context import ParallelMode +from .common import is_using_ddp +from typing import Dict, List + + +def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]: + """Returns a parameter dictionary, the key of which is the expert parallel + size of every parameter. Since the parameters in data parallelism is replicated + in each GPU, we set their ep_size to 1. + + Args: + model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict. + """ + epsize_param_dict = dict() + for param in model.parameters(): + if not hasattr(param, 'moe_info'): + ep_size = 1 # set ep_size to 1 for dp parameters + else: + ep_size = param.moe_info.ep_size + if ep_size not in epsize_param_dict: + epsize_param_dict[ep_size] = [] + epsize_param_dict[ep_size].append(param) + + return epsize_param_dict + + +def sync_moe_model_param(model: nn.Module): + """Make sure model parameters are consistent in MoE parallel context. + + Args: + model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. + """ + if is_using_ddp(): + + param_dict = get_moe_epsize_param_dict(model) + + # synchronize the parameters whose dp_group is the whole world + if 1 in param_dict: + src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] + for param in param_dict[1]: + dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA)) + + for ep_size in param_dict: + # When ep_size = world_size, communication is not needed + if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: + src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group) + for param in param_dict[ep_size]: + dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group) diff --git a/colossalai/utils/multi_tensor_apply/__init__.py b/colossalai/utils/multi_tensor_apply/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94d13b339a0de106334fee99a8b73ea6e70f60dd --- /dev/null +++ b/colossalai/utils/multi_tensor_apply/__init__.py @@ -0,0 +1,3 @@ +from .multi_tensor_apply import MultiTensorApply + +multi_tensor_applier = MultiTensorApply(2048 * 32) diff --git a/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..2b6de5fe1f3c810acaf7877a6db05286cd47af3c --- /dev/null +++ b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py @@ -0,0 +1,33 @@ +# modified from https://github.com/NVIDIA/apex/blob/master/apex/multi_tensor_apply/multi_tensor_apply.py + + +class MultiTensorApply(object): + """ + Apply an operation to a list of tensors efficiently. + + Args: + chunk_size (int): Size of a chunk. + """ + + available = False + warned = False + + def __init__(self, chunk_size): + try: + MultiTensorApply.available = True + self.chunk_size = chunk_size + except ImportError as err: + MultiTensorApply.available = False + MultiTensorApply.import_err = err + + def check_avail(self): + if not MultiTensorApply.available: + raise RuntimeError( + "Attempted to call MultiTensorApply method, but MultiTensorApply " + "is not available, possibly because Apex was installed without " + "--cpp_ext --cuda_ext. Original import error message:", MultiTensorApply.import_err) + + def __call__(self, op, noop_flag_buffer, tensor_lists, *args): + self.check_avail() + + return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args) diff --git a/colossalai/utils/profiler/__init__.py b/colossalai/utils/profiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..90eab67c4d27d7b6e76316be097290c9fc7fa5ce --- /dev/null +++ b/colossalai/utils/profiler/__init__.py @@ -0,0 +1,2 @@ +from .legacy import * +from .profiler import profile diff --git a/colossalai/utils/profiler/extention.py b/colossalai/utils/profiler/extention.py new file mode 100644 index 0000000000000000000000000000000000000000..6726a683cc05ebb1ac5370d8c17750cd869d9ec2 --- /dev/null +++ b/colossalai/utils/profiler/extention.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod + + +class ProfilerExtension(ABC): + + @abstractmethod + def prepare_trace(self): + pass + + @abstractmethod + def start_trace(self): + pass + + @abstractmethod + def stop_trace(self): + pass + + @abstractmethod + def extend_chrome_trace(self, trace: dict) -> dict: + pass diff --git a/colossalai/utils/profiler/legacy/__init__.py b/colossalai/utils/profiler/legacy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..849c7fca305315c267dcc1bcbb014d44225d739f --- /dev/null +++ b/colossalai/utils/profiler/legacy/__init__.py @@ -0,0 +1,6 @@ +from .comm_profiler import CommProfiler +from .pcie_profiler import PcieProfiler +from .prof_utils import ProfilerContext, BaseProfiler +from .mem_profiler import MemProfiler + +__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext'] diff --git a/colossalai/utils/profiler/legacy/comm_profiler.py b/colossalai/utils/profiler/legacy/comm_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f5729c97ec4b4365e40b69c9a0fda7de5055d9 --- /dev/null +++ b/colossalai/utils/profiler/legacy/comm_profiler.py @@ -0,0 +1,308 @@ +import inspect +from pathlib import Path +from functools import partial +import torch +from torch.autograd.profiler import profile +import torch.distributed as dist +from torch.distributed import ReduceOp +from colossalai.utils import get_current_device +from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth +from typing import List, Optional + + +def _get_code_location(depth: int): + ret = [] + length = min(len(inspect.stack()), depth + 1) + for i in range(3, length): + upper_frame = inspect.stack()[i] + function_name = inspect.stack()[i - 1].function + ret.append(upper_frame.filename) + ret.append('(') + ret.append(str(upper_frame.lineno)) + ret.append('): ') + ret.append(function_name) + if i != length - 1: + ret.append('\n') + + return ''.join(ret) + + +torch_all_reduce = dist.all_reduce +torch_all_gather = dist.all_gather +torch_reduce_scatter = dist.reduce_scatter +torch_broadcast = dist.broadcast +torch_reduce = dist.reduce + + +class CommEvent(object): + """Communication Event. Used for communication time and communication + volume recording. + """ + + def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0): + self.self_count = count + self.self_comm_vol = comm_vol + self.self_cuda_time = cuda_time + + def add(self, rhs): + self.self_count += rhs.self_count + self.self_comm_vol += rhs.self_comm_vol + self.self_cuda_time += rhs.self_cuda_time + + +class CommProfiler(BaseProfiler): + """Communication profiler. Records all communication events. + """ + + def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0): + super().__init__(profiler_name="Collective_Communication", priority=0) + self.depth = 3 + depth + self.total_count = total_count + self.total_comm_vol = total_comm_vol + self.total_cuda_time = total_cuda_time + + self.ops_record = dict() + self.profiler = None + self.pending_op = None + self.pending_metadata = None + self.warn_flag = False + + def reset(self): + self.total_count = 0 + self.total_comm_vol = 0 + self.total_cuda_time = 0 + + self.ops_record = dict() + self.profiler = None + self.pending_op = None + self.pending_metadata = None + self.warn_flag = False + + def enable(self): + dist.all_reduce = partial(all_reduce, profiler=self) + dist.all_gather = partial(all_gather, profiler=self) + dist.reduce_scatter = partial(reduce_scatter, profiler=self) + dist.broadcast = partial(broadcast, profiler=self) + dist.reduce = partial(reduce, profiler=self) + + def disable(self): + dist.all_reduce = torch_all_reduce + dist.all_gather = torch_all_gather + dist.reduce_scatter = torch_reduce_scatter + dist.broadcast = torch_broadcast + dist.reduce = torch_reduce + + def to_tensorboard(self, writer): + writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n")) + + def to_file(self, filename: Path): + with open(filename, "w") as f: + f.write(self.result_str()) + + def show(self): + print(self.result_str()) + + def result_str(self, sep: str = "\n"): + res = [] + + def append(s: str = None): + if s is not None: + res.append(s) + res.append(sep) + + if self.warn_flag: + append("Warnning: there exists multiple communication operations in the same time. As a result, " + "the profiling result is not accurate.") + + if self.total_cuda_time == 0: + return "No collective communication has been called yet!" + + append("Collective communication profiling result:") + append("total cuda time: {}".format(_format_time(self.total_cuda_time))) + append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time))) + append("total number of calls: {}".format(self.total_count)) + append("All events:") + + seperation = '-' * 74 + row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2 + + append(seperation) + append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls')) + append(seperation) + + show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time) + for location, event in show_list: + append(location) + append( + row_format.format('', _format_time(event.self_cuda_time), + '{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0), + _format_memory(event.self_comm_vol), + _format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count)) + append() + + return ''.join(res) + + @property + def has_aync_op(self): + return self.pending_op is not None + + def activate_profiler(self, kn: str, vol: float): + self.pending_metadata = (kn, _get_code_location(self.depth), vol) + self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True) + self.profiler.__enter__() + + def close_profiler(self, group=None): + assert self.profiler is not None, "There is no running dist op" + kernel_name, code_location, vol = self.pending_metadata + self.profiler.__exit__(None, None, None) + + if self.profiler.enabled and dist.get_world_size(group) > 1: + assert_flag = 0 + current_comm_event = None + events = self.profiler.function_events + for event in events: + if kernel_name in event.name: + assert assert_flag == 0, "Multiple dist ops has been called " + current_comm_event = CommEvent(1, vol, event.self_cuda_time_total) + assert_flag += 1 + + assert current_comm_event is not None, "dist op has not been found" + + buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device()) + torch_all_reduce(buffer, op=ReduceOp.MIN, group=group) + current_comm_event.self_cuda_time = buffer.item() + + self.total_count += current_comm_event.self_count + self.total_comm_vol += current_comm_event.self_comm_vol + self.total_cuda_time += current_comm_event.self_cuda_time + if code_location in self.ops_record: + self.ops_record[code_location].add(current_comm_event) + else: + self.ops_record[code_location] = current_comm_event + + self.profiler = None + self.pending_op = None + self.pending_metadata = None + + def wait_async_op(self): + if self.pending_op is not None: + op = self.pending_op + op.wait() + self.close_profiler() + + +class CommHandler(object): + """Communication handler. A dummy handler to wait aync operations. + """ + + def __init__(self, profiler: CommProfiler): + super().__init__() + self.prof = profiler + + def wait(self): + self.prof.wait_async_op() + + +def async_check(profiler: CommProfiler): + if profiler.pending_op is not None: + profiler.warn_flag = True + profiler.wait_async_op() + + +def all_reduce(tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = 2 * (comm_size - 1) / comm_size + comm_vol = correction * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol) + profiler.pending_op = torch_all_reduce(tensor, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def reduce_scatter(output: torch.Tensor, + input_list: List[torch.Tensor], + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = (comm_size - 1) / comm_size + comm_vol = 0 + for tensor in input_list: + comm_vol += tensor.element_size() * tensor.numel() + comm_vol *= correction + profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol) + profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def all_gather(tensor_list: List[torch.Tensor], + tensor: torch.Tensor, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = (comm_size - 1) / comm_size + comm_vol = 0 + for ten in tensor_list: + comm_vol += ten.element_size() * ten.numel() + comm_vol *= correction + profiler.activate_profiler("ncclKernel_AllGather_", comm_vol) + profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def broadcast(tensor: torch.Tensor, + src: int, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_vol = 1.0 * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol) + profiler.pending_op = torch_broadcast(tensor, src, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def reduce(tensor: torch.Tensor, + dst: int, + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_vol = 1.0 * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_Reduce_", comm_vol) + profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) diff --git a/colossalai/utils/profiler/legacy/pcie_profiler.py b/colossalai/utils/profiler/legacy/pcie_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..526222941ef979c1ff805349f40d51e9a6fdd569 --- /dev/null +++ b/colossalai/utils/profiler/legacy/pcie_profiler.py @@ -0,0 +1,148 @@ +from pathlib import Path +from torch.autograd.profiler import profile +from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth +from typing import List + + +def _get_size(dtype: str): + if dtype == "fp16": + return 2 + elif dtype == "fp32": + return 4 + else: + raise NotImplementedError + + +def _get_numel(my_list: List[int]) -> int: + from functools import reduce + from operator import mul + return reduce(mul, my_list) + + +def _reduce_location(locations: List[str]) -> str: + ret = [] + for lo in locations: + ret.append(lo) + ret.append("\n") + ret = ret[:-1] + return ''.join(ret) + + +class PcieEvent(object): + """Pcie Event. + """ + + def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0): + self.count = count + self.pcie_vol = pcie_vol + self.cuda_time = cuda_time + + def add(self, rhs): + self.count += rhs.count + self.pcie_vol += rhs.pcie_vol + self.cuda_time += rhs.cuda_time + + +class PcieProfiler(BaseProfiler): + """Pcie profiler. Records all data transmission between CPU and GPU. + + TODO: Merge pcie profiler into communication profiler + """ + + def __init__(self, dtype: str = "fp32", depth: int = 1): + super().__init__(profiler_name="Pcie", priority=10) + self.depth = depth + self.data_size = _get_size(dtype) + self.h2d_count = 0 + self.h2d_time = 0 + self.d2h_count = 0 + self.d2h_time = 0 + + self.ops_record = dict() + self.profiler = None + + def reset(self): + self.h2d_count = 0 + self.h2d_time = 0 + self.d2h_count = 0 + self.d2h_time = 0 + + self.ops_record = dict() + self.profiler = None + + def enable(self): + self.profiler = profile(enabled=True, + use_cuda=True, + use_cpu=True, + use_kineto=True, + record_shapes=True, + with_stack=True) + self.profiler.__enter__() + + def disable(self): + self.profiler.__exit__(None, None, None) + + if self.profiler.enabled: + events = self.profiler.function_events + for event in events: + if event.name == "aten::copy_": + t_shape = event.input_shapes[0] + if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0: + continue + current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total) + code_location = _reduce_location(event.stack[:self.depth]) + if code_location in self.ops_record: + self.ops_record[code_location].add(current_comm_event) + else: + self.ops_record[code_location] = current_comm_event + elif 'Memcpy HtoD' in event.name: + self.h2d_count += 1 + self.h2d_time += event.cuda_time_total + elif 'Memcpy DtoH' in event.name: + self.d2h_count += 1 + self.d2h_time += event.cuda_time_total + + self.profiler = None + + def to_tensorboard(self, writer): + writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n")) + + def to_file(self, filename: Path): + with open(filename, "w") as f: + f.write(self.result_str()) + + def show(self): + print(self.result_str()) + + def result_str(self, sep: str = "\n"): + res = [] + + def append(s: str = None): + if s is not None: + res.append(s) + res.append(sep) + + append("Pcie profiling result:") + append("time of data transmission (CPU -> GPU): {}".format(_format_time(self.h2d_time))) + append("number of transmission (CPU -> GPU): {}".format(self.h2d_count)) + append("time of data transmission (GPU -> CPU): {}".format(_format_time(self.d2h_time))) + append("number of transmission (GPU -> CPU): {}".format(self.d2h_count)) + + append("Possible data transmission events in PCIE:") + + seperation = '-' * 62 + row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2 + + append(seperation) + append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls')) + append(seperation) + + show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time) + for location, event in show_list: + append(location) + append( + row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol), + _format_bandwidth(event.pcie_vol, event.cuda_time), event.count)) + append() + + return ''.join(res) diff --git a/colossalai/utils/profiler/legacy/prof_utils.py b/colossalai/utils/profiler/legacy/prof_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..87ad644a7ecc989e56ecc581ab2a3ba8d609b94f --- /dev/null +++ b/colossalai/utils/profiler/legacy/prof_utils.py @@ -0,0 +1,131 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Union, List +from colossalai.core import global_context as gpc + + +# copied from high version pytorch to support low version +def _format_time(time_us): + """Defines how to format time in FunctionEvent""" + US_IN_SECOND = 1000.0 * 1000.0 + US_IN_MS = 1000.0 + if time_us >= US_IN_SECOND: + return '{:.3f}s'.format(time_us / US_IN_SECOND) + if time_us >= US_IN_MS: + return '{:.3f}ms'.format(time_us / US_IN_MS) + return '{:.3f}us'.format(time_us) + + +# copied from high version pytorch to support low version +def _format_memory(nbytes): + """Returns a formatted memory size string""" + KB = 1024 + MB = 1024 * KB + GB = 1024 * MB + if (abs(nbytes) >= GB): + return '{:.2f} GB'.format(nbytes * 1.0 / GB) + elif (abs(nbytes) >= MB): + return '{:.2f} MB'.format(nbytes * 1.0 / MB) + elif (abs(nbytes) >= KB): + return '{:.2f} KB'.format(nbytes * 1.0 / KB) + else: + return str(nbytes) + ' B' + + +def _format_bandwidth(volme: float or int, time_us: int): + sec_div_mb = (1000.0 / 1024.0)**2 + mb_per_sec = volme / time_us * sec_div_mb + + if mb_per_sec >= 1024.0: + return '{:.3f} GB/s'.format(mb_per_sec / 1024.0) + else: + return '{:.3f} MB/s'.format(mb_per_sec) + + +class BaseProfiler(ABC): + + def __init__(self, profiler_name: str, priority: int): + self.name = profiler_name + self.priority = priority + + @abstractmethod + def enable(self): + pass + + @abstractmethod + def disable(self): + pass + + @abstractmethod + def to_tensorboard(self, writer): + pass + + @abstractmethod + def to_file(self, filename: Path): + pass + + @abstractmethod + def show(self): + pass + + +class ProfilerContext(object): + """Profiler context manager + + Usage:: + + world_size = 4 + inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device()) + outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device()) + outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0)) + + cc_prof = CommProfiler() + + with ProfilerContext([cc_prof]) as prof: + op = dist.all_reduce(inputs, async_op=True) + dist.all_gather(outputs_list, inputs) + op.wait() + dist.reduce_scatter(inputs, outputs_list) + dist.broadcast(inputs, 0) + dist.reduce(inputs, 0) + + prof.show() + """ + + def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True): + self.enable = enable + self.profilers = sorted(profilers, key=lambda prof: prof.priority) + + def __enter__(self): + if self.enable: + for prof in self.profilers: + prof.enable() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.enable: + for prof in self.profilers: + prof.disable() + + def to_tensorboard(self, writer): + from torch.utils.tensorboard import SummaryWriter + + assert isinstance(writer, SummaryWriter), \ + f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.' + + for prof in self.profilers: + prof.to_tensorboard(writer) + + def to_file(self, log_dir: Union[str, Path]): + if isinstance(log_dir, str): + log_dir = Path(log_dir) + + if not log_dir.exists(): + log_dir.mkdir(parents=True, exist_ok=True) + for prof in self.profilers: + log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log') + prof.to_file(log_file) + + def show(self): + for prof in self.profilers: + prof.show() diff --git a/colossalai/utils/profiler/profiler.py b/colossalai/utils/profiler/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..8f43a0b96de0a8deb4818644bb48df7ddaad896c --- /dev/null +++ b/colossalai/utils/profiler/profiler.py @@ -0,0 +1,201 @@ +import os +from typing import List +from colossalai.engine import Engine +from torch.profiler import profile as torch_profile +from torch.profiler.profiler import ProfilerAction +from typing import Any, Callable, Iterable, Optional +from torch.autograd import ProfilerActivity +import json +import os +import tempfile +import gzip +from colossalai.utils.profiler.extention import ProfilerExtension +from colossalai.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention +from colossalai.logging import get_dist_logger + + +class profile(torch_profile): + """Profiler context manager. + + Args: + activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values: + ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``. + Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA. + schedule (callable): callable that takes step (int) as a single parameter and returns + ``ProfilerAction`` value that specifies the profiler action to perform at each step. + on_trace_ready (callable): callable that is called at each step when ``schedule`` + returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling. + engine (Optional[Engine], optional): An ``Engine`` instance. Defaults to None. + record_shapes (bool): save information about operator's input shapes. + profile_memory (bool): track tensor memory allocation/deallocation. + with_stack (bool): record source information (file and line number) for the ops. + with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators + (matrix multiplication and 2D convolution). + with_modules (bool): record module hierarchy (including function names) + corresponding to the callstack of the op. e.g. If module A's forward call's + module B's forward which contains an aten::add op, + then aten::add's module hierarchy is A.B + Note that this support exist, at the moment, only for TorchScript models + and not eager mode models. + profile_stateful_tensor_memory (bool): track stateful tensor memory usage. ``engine`` must not be None if you enable this. + + .. note:: + Use :func:`~torch.profiler.schedule` to generate the callable schedule. + Non-default schedules are useful when profiling long training jobs + and allow the user to obtain multiple traces at the different iterations + of the training process. + The default schedule simply records all the events continuously for the + duration of the context manager. + + .. note:: + Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard: + + ``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)`` + + After profiling, result files can be found in the specified directory. Use the command: + + ``tensorboard --logdir dir_name`` + + to see the results in TensorBoard. + For more information, see + `PyTorch Profiler TensorBoard Plugin `__ + + .. note:: + Enabling shape and stack tracing results in additional overhead. + When record_shapes=True is specified, profiler will temporarily hold references to the tensors; + that may further prevent certain optimizations that depend on the reference count and introduce + extra tensor copies. + + Examples: + + .. code-block:: python + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + code_to_profile() + print(p.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + + Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions: + + .. code-block:: python + + # Non-default profiler schedule allows user to turn profiler on and off + # on different iterations of the training loop; + # trace_handler is called every time a new trace becomes available + def trace_handler(prof): + print(prof.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json") + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + + # In this example with wait=1, warmup=1, active=2, + # profiler will skip the first step/iteration, + # start warming up on the second, record + # the third and the forth iterations, + # after which the trace will become available + # and on_trace_ready (when set) is called; + # the cycle repeats starting with the next step + + schedule=torch.profiler.schedule( + wait=1, + warmup=1, + active=2), + on_trace_ready=trace_handler + # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log') + # used when outputting for tensorboard + ) as p: + for iter in range(N): + code_iteration_to_profile(iter) + # send a signal to the profiler that the next iteration has started + p.step() + """ + + def __init__(self, + *, + activities: Optional[Iterable[ProfilerActivity]] = None, + schedule: Optional[Callable[[int], ProfilerAction]] = None, + on_trace_ready: Optional[Callable[..., Any]] = None, + engine: Optional[Engine] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + with_modules: bool = False, + profile_stateful_tensor_memory: bool = False) -> None: + super().__init__(activities=activities, + schedule=schedule, + on_trace_ready=on_trace_ready, + record_shapes=record_shapes, + profile_memory=profile_memory, + with_stack=with_stack, + with_flops=with_flops, + with_modules=with_modules) + self._logger = get_dist_logger() + self.extentions: List[ProfilerExtension] = [] + if profile_stateful_tensor_memory: + if engine is None: + self._logger.warning('Ignore "profile_model_data" since engine is None', ranks=[0]) + else: + self.extentions.append(StatefulTensorMemoryProfilerExtention(engine)) + + def prepare_trace(self) -> None: + if hasattr(super(), 'prepare_trace'): + super().prepare_trace() + elif hasattr(super(), '_start_warmup'): + super()._start_warmup() + for ext in self.extentions: + ext.prepare_trace() + + def _start_warmup(self): + self.prepare_trace() + + def start_trace(self): + if hasattr(super(), '_start_trace'): + super()._start_trace() + elif hasattr(super(), 'start_trace'): + super().start_trace() + for ext in self.extentions: + ext.start_trace() + + def _start_trace(self): + self.start_trace() + + def stop_trace(self): + if hasattr(super(), '_stop_trace'): + super()._stop_trace() + elif hasattr(super(), 'stop_trace'): + super().stop_trace() + for ext in self.extentions: + ext.stop_trace() + + def _stop_trace(self): + self.stop_trace() + + def export_chrome_trace(self, path: str): + """ + Exports the collected trace in Chrome JSON format. + """ + assert self.profiler + fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False) + fp.close() + retvalue = self.profiler.export_chrome_trace(fp.name) + with open(fp.name) as fin: + trace = json.load(fin) + for ext in self.extentions: + trace = ext.extend_chrome_trace(trace) + open_func = gzip.open if path.endswith('.gz') else open + with open_func(path, 'wt') as fout: + json.dump(trace, fout) + + os.remove(fp.name) + return retvalue diff --git a/colossalai/utils/profiler/stateful_tensor_mem_extention.py b/colossalai/utils/profiler/stateful_tensor_mem_extention.py new file mode 100644 index 0000000000000000000000000000000000000000..127055c8c1efa5e193ee04cbf6467df35a5972a7 --- /dev/null +++ b/colossalai/utils/profiler/stateful_tensor_mem_extention.py @@ -0,0 +1,133 @@ +import os +import threading +import time +import torch +from enum import Enum +from typing import List +from colossalai.gemini.stateful_tensor import StatefulTensor +from colossalai.gemini.ophooks import BaseOpHook +from colossalai.engine import Engine +from colossalai.utils.profiler.extention import ProfilerExtension + + +class DeviceType(Enum): + CPU = 0 + CUDA = 1 + + +def get_timestamp_us(): + return int(time.time() * 1e6) + + +def generic_instant_event(name, pid, tid, timestamp, args): + return {'ph': 'i', 's': 't', 'name': name, 'pid': pid, 'tid': tid, 'ts': timestamp, 'args': args} + + +class StatefulTensorMemoryEvent: + EVENT_NAME = '[statefulTensorMemory]' + + def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None: + self.pid = os.getpid() + self.tid = threading.get_ident() + self.timestamp = timestamp + self.device_type = device_type + self.device_id = torch.cuda.current_device() if device_type == DeviceType.CUDA else -1 + self.bytes = bytes_ + + def state_dict(self): + return generic_instant_event(StatefulTensorMemoryEvent.EVENT_NAME, self.pid, self.tid, self.timestamp, { + 'Device Type': self.device_type.value, + 'Device Id': self.device_id, + 'Bytes': self.bytes + }) + + +class StatefulTensorMemoryTracer: + + def __init__(self) -> None: + self.events: List[StatefulTensorMemoryEvent] = [] + self._tracing = False + + def sample(self): + cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] + cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu'] + timestamp = get_timestamp_us() + if self._tracing: + self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem)) + self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CPU, cpu_mem)) + + def start_trace(self): + self.events.clear() + self._tracing = True + + def stop_trace(self): + self._tracing = False + + def state_dict(self): + return [event.state_dict() for event in self.events] + + +class StatefulTensorMemoryTracerHook(BaseOpHook): + + def __init__(self, tracer: StatefulTensorMemoryTracer): + super().__init__() + self.tracer = tracer + self._enable = False + + def pre_fwd_exec(self, module: torch.nn.Module, *args): + if self._enable: + self.tracer.sample() + + def post_fwd_exec(self, module: torch.nn.Module, *args): + if self._enable: + self.tracer.sample() + + def pre_bwd_exec(self, module: torch.nn.Module, input_, output): + if self._enable: + self.tracer.sample() + + def post_bwd_exec(self, module: torch.nn.Module, input_): + if self._enable: + self.tracer.sample() + + def post_iter(self): + if self._enable: + self.tracer.sample() + + def enable(self): + self._enable = True + + def disable(self): + self._enable = False + + +class StatefulTensorMemoryProfilerExtention(ProfilerExtension): + + def __init__(self, engine: Engine) -> None: + self.engine = engine + self.tracer = StatefulTensorMemoryTracer() + self.hook = StatefulTensorMemoryTracerHook(self.tracer) + self.hook_registered = False + + def prepare_trace(self): + self.hook.enable() + if not self.hook_registered: + self.engine.add_hook(self.hook) + self.hook_registered = True + + def start_trace(self): + self.prepare_trace() + self.tracer.start_trace() + + def stop_trace(self): + self.tracer.stop_trace() + self.hook.disable() + if self.hook_registered: + self.engine.remove_hook(self.hook) + # remove_hook is not implemented now + # FIXME(ver217): uncomment below line when remove_hook is implemented + # self.hook_registered = False + + def extend_chrome_trace(self, trace: dict) -> dict: + trace['traceEvents'].extend(self.tracer.state_dict()) + return trace diff --git a/colossalai/utils/rank_recorder/README.md b/colossalai/utils/rank_recorder/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e30a925d2a9291de9d8eeb01119237d24a1bc38c --- /dev/null +++ b/colossalai/utils/rank_recorder/README.md @@ -0,0 +1,72 @@ +# Rank Recorder +This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualise the json file easily. + +Before using the tool, you should ensure dist.is_initialized() return true before exit of program. + +## Usage + +Is very simple: + +```python +from colossalai.utils.rank_recorder import recorder + +... +... + +with recorder(record_name, current_rank) as r: + """procedure to record + """ + +``` + +## Example +This is a demo to display kernel select in cuda and visualise the cost of several procedures in each rank. + +```python +import time +import os +import logging +logging.disable(logging.INFO) + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from colossalai.utils.rank_recorder import recorder + + +WORLD_SIZE = 4 + +# config the export image here +# If you want to dive into the detail, format 'svg' is recommended +recorder.export_format = 'png' +recorder.export_name = 'kernel_select' +recorder.dpi = 500 + +def calc(x, y): + a = torch.randn(x, y).cuda() + b = torch.randn(x, y).cuda() + c = sum(a * b) + return c + +def worker(rank): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '29020' + dist.init_process_group(backend='nccl', world_size=WORLD_SIZE, rank=rank) + print(dist.get_rank(), "enter") + time.sleep(0.1 * rank) + + with recorder("calc_1(x100)", rank) as r: + calc(100, 100) + + with recorder("calc_2(x400)", rank) as r: + calc(400, 400) + + with recorder("calc_2(x200)", rank) as r: + calc(200, 200) + +if __name__ == "__main__": + mp.spawn(worker, nprocs=WORLD_SIZE) +``` + +run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder. \ No newline at end of file diff --git a/colossalai/utils/rank_recorder/__init__.py b/colossalai/utils/rank_recorder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1274d0e7dbc5277a60fb4d6fcd16c972b85c305c --- /dev/null +++ b/colossalai/utils/rank_recorder/__init__.py @@ -0,0 +1,3 @@ +from colossalai.utils.rank_recorder.rank_recorder import recorder + +__all__ = ["recorder"] \ No newline at end of file diff --git a/colossalai/utils/rank_recorder/rank_recorder.py b/colossalai/utils/rank_recorder/rank_recorder.py new file mode 100644 index 0000000000000000000000000000000000000000..c088ceeb2e87727ca9f4e5ac4a71d0721405f4a1 --- /dev/null +++ b/colossalai/utils/rank_recorder/rank_recorder.py @@ -0,0 +1,178 @@ +import time +from typing import List, Dict +import json +import os +import time +import shutil +import atexit + +import torch +import torch.distributed as dist + +import json +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors + +cmap = list(mcolors.TABLEAU_COLORS.values()) + +LOG_FOLDER = "record.log" +MAX_WAIT_TIME = 20 + + +class Event: + + def __init__(self, start: int, end: int, name: str, rank: int) -> None: + self.start = start + self.end = end + self.name = name + self.rank = rank + + +class Recorder: + + def __init__(self) -> None: + self.rank_to_history: Dict[int, List[Event]] = {} + self.base_time = time.time() + self.temp_event = None + + self.export_format = 'png' + self.export_name = 'test' + self.dpi = 500 + self.theme = 'dark_background' + self.figure_width = 30 + self.figure_height = 10 + self.legend_fontsize = 16 + self.device_fontsize = 20 + self.bar_height = 0.2 + + if not os.path.exists(LOG_FOLDER): + os.makedirs(LOG_FOLDER) + + def start(self, name: str, rank: int): + # TODO : add lock to prevent conflict + torch.cuda.synchronize() + start_time = time.time() + self.temp_event = Event(start_time, None, name, rank) + + def end(self): + assert self.temp_event is not None, "`start` before `end`" + torch.cuda.synchronize() + end_time = time.time() + self.temp_event.end = end_time + rank = self.temp_event.rank + if rank not in self.rank_to_history: + self.rank_to_history[rank] = [] + self.rank_to_history[rank].append(self.temp_event) + self.temp_event = None + + def get_history(self): + return self.history + + def __call__(self, name: str, rank: str): + self.temp_name = name + self.temp_rank = rank + return self + + def __enter__(self): + name = self.temp_name + rank = self.temp_rank + self.start(name, rank) + + def __exit__(self, *args): + self.end() + + def dump_record(self): + rank = dist.get_rank() + rank_to_history = self.rank_to_history + records = {'base_time': self.base_time, 'content': {}} + for record_rank in rank_to_history: + history = rank_to_history[record_rank] + recs = [] + for event in history: + rec = {'start': event.start, 'end': event.end, 'name': event.name} + recs.append(rec) + records['content'][record_rank] = recs + + dump_name = f'{rank}.json' + dump_path = os.path.join(LOG_FOLDER, dump_name) + with open(dump_path, 'w', encoding='utf-8') as f: + json.dump(records, f, ensure_ascii=False) + + def merge_recode(self): + base_time = self.base_time + world_size = dist.get_world_size() + + wait_time = 0 + while True: + time.sleep(0.1) + log_num = len(os.listdir(LOG_FOLDER)) + if log_num == world_size: + break + + wait_time += 1 + if wait_time >= MAX_WAIT_TIME: + break + + # merge + logs_path = [os.path.join(LOG_FOLDER, file) for file in os.listdir(LOG_FOLDER)] + recoders = {} + for path in logs_path: + with open(path, 'r', encoding='utf-8') as f: + recs = json.load(f) + for record_rank in recs['content']: + history = recs['content'][record_rank] + recoders[record_rank] = [] + for rec in history: + recoders[record_rank].append({ + 'start': rec['start'] - base_time, + 'end': rec['end'] - base_time, + 'name': rec['name'] + }) + + shutil.rmtree(LOG_FOLDER) + with open(self.export_name + '.json', 'w', encoding='utf-8') as f: + json.dump(recoders, f, ensure_ascii=False) + + def visualise_record(self): + with open(self.export_name + '.json', 'r', encoding='utf-8') as f: + records = json.load(f) + records = dict(records) + ranks = list(sorted(records.keys())) + + name_list = {} + plots = {} + plt.figure(dpi=self.dpi, figsize=[self.figure_width, self.figure_height]) + plt.style.use(self.theme) + + for rank in ranks: + rank_records = records[rank] + for rec in rank_records: + s = rec['start'] + e = rec['end'] + name = rec['name'] + if name not in name_list: + name_list[name] = len(name_list) + bar = plt.barh(rank, width=e - s, height=self.bar_height, left=s, color=cmap[name_list[name]]) + if name not in plots: + plots[name] = bar + + plt.legend(list(plots.values()), list(plots.keys()), loc="upper left", fontsize=self.legend_fontsize) + plt.yticks(ticks=ranks, labels=[f'Device:{rank}' for rank in ranks], fontsize=self.device_fontsize) + plt.grid(axis='x') + plt.savefig("{}.{}".format(self.export_name, self.export_format)) + + def exit_worker(self): + if len(self.rank_to_history) == 0: + return + self.dump_record() + # if this is rank 0, wait for merge + rank = dist.get_rank() + + if rank == 1: + # take the base time of rank 0 as standard + self.merge_recode() + self.visualise_record() + + +recorder = Recorder() +atexit.register(recorder.exit_worker) diff --git a/colossalai/utils/tensor_detector/__init__.py b/colossalai/utils/tensor_detector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cafc19b67c5c3b5f585e3dbef49f6bdbb85d5755 --- /dev/null +++ b/colossalai/utils/tensor_detector/__init__.py @@ -0,0 +1 @@ +from .tensor_detector import TensorDetector diff --git a/colossalai/utils/tensor_detector/readme.md b/colossalai/utils/tensor_detector/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..840dc8f4eca648f2d8b4ffc286b430003169216e --- /dev/null +++ b/colossalai/utils/tensor_detector/readme.md @@ -0,0 +1,128 @@ +# Tensor Detector + +This tool supports you to detect tensors on both CPU and GPU. However, there will always be some strange tensors on CPU, including the rng state of PyTorch. + +## Example + +An example is worth than a thousand words. + +The code below defines a simple MLP module, with which we will show you how to use the tool. + +```python +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.mlp = nn.Sequential(nn.Linear(64, 8), + nn.ReLU(), + nn.Linear(8, 32)) + def forward(self, x): + return self.mlp(x) +``` + +And here is how to use the tool. + +```python +from colossalai.utils import TensorDetector + +# create random data +data = torch.rand(64, requires_grad=True).cuda() +data.retain_grad() +# create the module +model = MLP().cuda() +# create the detector +# by passing the model to the detector, it can distinguish module parameters from common tensors +detector = TensorDetector(include_cpu=False, module=model) +detector.detect() + +out = model(data) + +detector.detect() + +loss = out.sum() +loss.backward() + +detector.detect() +``` + +I have made some comments on the right of the output for your understanding. + +Note that the total `Mem` of all the tensors and parameters is not equal to `Total GPU Memery Allocated`. PyTorch's memory management is really complicated, and for models of a large scale, it's impossible to figure out clearly. + +**The order of print is not equal to the order the tensor creates, but they are really close.** + +```bash +------------------------------------------------------------------------------------------------------------ + Tensor device shape grad dtype Mem +------------------------------------------------------------------------------------------------------------ ++ Tensor cuda:0 (64,) True torch.float32 256 B # data ++ mlp.0.weight cuda:0 (8, 64) True torch.float32 2.0 KB ++ mlp.0.bias cuda:0 (8,) True torch.float32 32 B ++ mlp.2.weight cuda:0 (32, 8) True torch.float32 1.0 KB ++ mlp.2.bias cuda:0 (32,) True torch.float32 128 B +------------------------------------------------------------------------------------------------------------ +Detect Location: "test_tensor_detector.py" line 27 +Totle GPU Memery Allocated on cuda:0 is 4.5 KB +------------------------------------------------------------------------------------------------------------ + + +------------------------------------------------------------------------------------------------------------ + Tensor device shape grad dtype Mem +------------------------------------------------------------------------------------------------------------ ++ Tensor cuda:0 (8,) True torch.float32 32 B # activation ++ Tensor cuda:0 (32,) True torch.float32 128 B # output +------------------------------------------------------------------------------------------------------------ +Detect Location: "test_tensor_detector.py" line 30 +Totle GPU Memery Allocated on cuda:0 is 5.5 KB +------------------------------------------------------------------------------------------------------------ + + +------------------------------------------------------------------------------------------------------------ + Tensor device shape grad dtype Mem +------------------------------------------------------------------------------------------------------------ ++ Tensor cuda:0 () True torch.float32 4 B # loss +------------------------------------------------------------------------------------------------------------ +Detect Location: "test_tensor_detector.py" line 32 +Totle GPU Memery Allocated on cuda:0 is 6.0 KB +------------------------------------------------------------------------------------------------------------ + + +------------------------------------------------------------------------------------------------------------ + Tensor device shape grad dtype Mem +------------------------------------------------------------------------------------------------------------ ++ Tensor (with grad) cuda:0 (64,) True torch.float32 512 B # data with grad ++ mlp.0.weight (with grad) cuda:0 (8, 64) True torch.float32 4.0 KB # for use data.retain_grad() ++ mlp.0.bias (with grad) cuda:0 (8,) True torch.float32 64 B ++ mlp.2.weight (with grad) cuda:0 (32, 8) True torch.float32 2.0 KB ++ mlp.2.bias (with grad) cuda:0 (32,) True torch.float32 256 B + +- mlp.0.weight cuda:0 (8, 64) True torch.float32 2.0 KB +- mlp.0.bias cuda:0 (8,) True torch.float32 32 B +- mlp.2.weight cuda:0 (32, 8) True torch.float32 1.0 KB +- mlp.2.bias cuda:0 (32,) True torch.float32 128 B +- Tensor cuda:0 (64,) True torch.float32 256 B +- Tensor cuda:0 (8,) True torch.float32 32 B # deleted activation +------------------------------------------------------------------------------------------------------------ +Detect Location: "test_tensor_detector.py" line 34 +Totle GPU Memery Allocated on cuda:0 is 10.0 KB +------------------------------------------------------------------------------------------------------------ + + +------------------------------------------------------------------------------------------------------------ + Tensor device shape grad dtype Mem +------------------------------------------------------------------------------------------------------------ ++ Tensor cuda:0 (64,) False torch.float32 256 B ++ Tensor cuda:0 (8, 64) False torch.float32 2.0 KB ++ Tensor cuda:0 (8,) False torch.float32 32 B ++ Tensor cuda:0 (32, 8) False torch.float32 1.0 KB ++ Tensor cuda:0 (32,) False torch.float32 128 B +------------------------------------------------------------------------------------------------------------ +Detect Location: "test_tensor_detector.py" line 36 +Totle GPU Memery Allocated on cuda:0 is 14.0 KB +------------------------------------------------------------------------------------------------------------ +``` + +## Reference + + This tool was inspired by https://github.com/Stonesjtu/pytorch_memlab/blob/master/pytorch_memlab/mem_reporter.py + and https://github.com/Oldpan/Pytorch-Memory-Utils + diff --git a/colossalai/utils/tensor_detector/tensor_detector.py b/colossalai/utils/tensor_detector/tensor_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..a8186f76834c1eec3fdb28c7b6e12dfd6e65260f --- /dev/null +++ b/colossalai/utils/tensor_detector/tensor_detector.py @@ -0,0 +1,179 @@ +import gc +import inspect +import torch +import torch.nn as nn +from typing import Optional +from collections import defaultdict + +LINE_WIDTH = 108 +LINE = '-' * LINE_WIDTH + '\n' + + +class TensorDetector(): + + def __init__(self, + show_info: bool = True, + log: str = None, + include_cpu: bool = False, + module: Optional[nn.Module] = None): + """This class is a detector to detect tensor on different devices. + + Args: + show_info (bool, optional): whether to print the info on screen, default True. + log (str, optional): the file name to save the log. Defaults to None. + include_cpu (bool, optional): whether to detect tensor on cpu, default False. + module (Optional[:class:`nn.Module`]): when sending an ``nn.Module`` object, + the detector can name the tensors detected better. + """ + self.show_info = show_info + self.log = log + self.include_cpu = include_cpu + self.tensor_info = defaultdict(list) + self.saved_tensor_info = defaultdict(list) + self.order = [] + self.detected = [] + self.devices = [] + self.info = "" + + self.module = module + if isinstance(module, nn.Module): + # if module is an instance of nn.Module, we can name the parameter with its real name + for name, param in module.named_parameters(): + self.tensor_info[id(param)].append(name) + self.tensor_info[id(param)].append(param.device) + self.tensor_info[id(param)].append(param.shape) + self.tensor_info[id(param)].append(param.requires_grad) + self.tensor_info[id(param)].append(param.dtype) + self.tensor_info[id(param)].append(self.get_tensor_mem(param)) + + def get_tensor_mem(self, tensor): + # calculate the memory occupied by a tensor + memory_size = tensor.element_size() * tensor.storage().size() + if (tensor.is_leaf or tensor.retains_grad) and tensor.grad is not None: + grad_memory_size = tensor.grad.element_size() * tensor.grad.storage().size() + memory_size += grad_memory_size + return self.mem_format(memory_size) + + def mem_format(self, real_memory_size): + # format the tensor memory into a reasonal magnitude + if real_memory_size >= 2**30: + return str(real_memory_size / (2**30)) + ' GB' + if real_memory_size >= 2**20: + return str(real_memory_size / (2**20)) + ' MB' + if real_memory_size >= 2**10: + return str(real_memory_size / (2**10)) + ' KB' + return str(real_memory_size) + ' B' + + def collect_tensors_state(self): + for obj in gc.get_objects(): + if torch.is_tensor(obj): + # skip cpu tensor when include_cpu is false and the tensor we have collected before + if (not self.include_cpu) and obj.device == torch.device('cpu'): + continue + self.detected.append(id(obj)) + # skip paramters we had added in __init__ when module is an instance of nn.Module for the first epoch + if id(obj) not in self.tensor_info: + + name = type(obj).__name__ + # after backward, we want to update the records, to show you the change + if isinstance(self.module, nn.Module) and name == 'Parameter': + if obj.grad is not None: + # with grad attached + for par_name, param in self.module.named_parameters(): + if param.requires_grad and param.grad.equal(obj.grad): + name = par_name + ' (with grad)' + else: + # with no grad attached + # there will be no new paramters created during running + # so it must be in saved_tensor_info + continue + # we can also marked common tensors as tensor(with grad) + if name == 'Tensor' and (obj.is_leaf or obj.retains_grad): + if obj.grad is not None: + name = name + ' (with grad)' + # in fact, common tensor have no grad + # unless you set retain_grad() + if id(obj) in self.saved_tensor_info.keys() and name == self.saved_tensor_info[id(obj)][0]: + continue + + self.tensor_info[id(obj)].append(name) + self.tensor_info[id(obj)].append(obj.device) + self.tensor_info[id(obj)].append(obj.shape) + self.tensor_info[id(obj)].append(obj.requires_grad) + self.tensor_info[id(obj)].append(obj.dtype) + self.tensor_info[id(obj)].append(self.get_tensor_mem(obj)) + # recorded the order we got the tensor + # by this we can guess the tensor easily + # it will record every tensor updated this turn + self.order.append(id(obj)) + # recorded all different devices + if obj.device not in self.devices: + self.devices.append(obj.device) + + def print_tensors_state(self): + template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}' + self.info += LINE + self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem') + self.info += '\n' + self.info += LINE + + # if a tensor updates this turn, and was recorded before + # it should be updated in the saved_tensor_info as well + outdated = [x for x in self.saved_tensor_info.keys() if x in self.order] + minus = [x for x in self.saved_tensor_info.keys() if x not in self.detected] + minus = outdated + minus + if len(self.order) > 0: + for tensor_id in self.order: + self.info += template_format.format('+', str(self.tensor_info[tensor_id][0]), + str(self.tensor_info[tensor_id][1]), + str(tuple(self.tensor_info[tensor_id][2])), + str(self.tensor_info[tensor_id][3]), + str(self.tensor_info[tensor_id][4]), + str(self.tensor_info[tensor_id][5])) + self.info += '\n' + if len(self.order) > 0 and len(minus) > 0: + self.info += '\n' + if len(minus) > 0: + for tensor_id in minus: + self.info += template_format.format('-', str(self.saved_tensor_info[tensor_id][0]), + str(self.saved_tensor_info[tensor_id][1]), + str(tuple(self.saved_tensor_info[tensor_id][2])), + str(self.saved_tensor_info[tensor_id][3]), + str(self.saved_tensor_info[tensor_id][4]), + str(self.saved_tensor_info[tensor_id][5])) + self.info += '\n' + # deleted the updated tensor + self.saved_tensor_info.pop(tensor_id) + + # trace where is the detect() + locate_info = inspect.stack()[2] + locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno) + + self.info += LINE + self.info += f"Detect Location: {locate_msg}\n" + for device in self.devices: + if device == torch.device('cpu'): + continue + gpu_mem_alloc = self.mem_format(torch.cuda.memory_allocated(device)) + self.info += f"Totle GPU Memery Allocated on {device} is {gpu_mem_alloc}\n" + self.info += LINE + self.info += '\n\n' + if self.show_info: + print(self.info) + if self.log is not None: + with open(self.log + '.log', 'a') as f: + f.write(self.info) + + def detect(self, include_cpu=False): + self.include_cpu = include_cpu + self.collect_tensors_state() + self.print_tensors_state() + self.saved_tensor_info.update(self.tensor_info) + self.tensor_info.clear() + self.order = [] + self.detected = [] + self.info = "" + + def close(self): + self.saved_tensor_info.clear() + self.module = None diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..4b61f4a5ef1148b74962f349cd24d852da899737 --- /dev/null +++ b/colossalai/utils/timer.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import time +from typing import Tuple +from .cuda import synchronize + + +class Timer: + """A timer object which helps to log the execution times, and provides different tools to assess the times. + """ + + def __init__(self): + self._started = False + self._start_time = time.time() + self._elapsed = 0 + self._history = [] + + @property + def has_history(self): + return len(self._history) != 0 + + @property + def current_time(self) -> float: + synchronize() + return time.time() + + def start(self): + """Firstly synchronize cuda, reset the clock and then start the timer. + """ + self._elapsed = 0 + synchronize() + self._start_time = time.time() + self._started = True + + def lap(self): + """lap time and return elapsed time + """ + return self.current_time - self._start_time + + def stop(self, keep_in_history: bool = False): + """Stop the timer and record the start-stop time interval. + + Args: + keep_in_history (bool, optional): Whether does it record into history + each start-stop interval, defaults to False. + Returns: + int: Start-stop interval. + """ + synchronize() + end_time = time.time() + elapsed = end_time - self._start_time + if keep_in_history: + self._history.append(elapsed) + self._elapsed = elapsed + self._started = False + return elapsed + + def get_history_mean(self): + """Mean of all history start-stop time intervals. + + Returns: + int: Mean of time intervals + """ + return sum(self._history) / len(self._history) + + def get_history_sum(self): + """Add up all the start-stop time intervals. + + Returns: + int: Sum of time intervals. + """ + return sum(self._history) + + def get_elapsed_time(self): + """Return the last start-stop time interval. + + Returns: + int: The last time interval. + + Note: + Use it only when timer is not in progress + """ + assert not self._started, 'Timer is still in progress' + return self._elapsed + + def reset(self): + """Clear up the timer and its history + """ + self._history = [] + self._started = False + self._elapsed = 0 + + +class MultiTimer: + """An object contains multiple timers. + + Args: + on (bool, optional): Whether the timer is enabled. Default is True. + """ + + def __init__(self, on: bool = True): + self._on = on + self._timers = dict() + + def start(self, name: str): + """Start namely one of the timers. + + Args: + name (str): Timer's key. + """ + if self._on: + if name not in self._timers: + self._timers[name] = Timer() + return self._timers[name].start() + + def stop(self, name: str, keep_in_history: bool): + """Stop namely one of the timers. + + Args: + name (str): Timer's key. + keep_in_history (bool): Whether does it record into history each start-stop interval. + """ + if self._on: + return self._timers[name].stop(keep_in_history) + else: + return None + + def get_timer(self, name): + """Get timer by its name (from multitimer) + + Args: + name (str): Timer's key. + Returns: + :class:`colossalai.utils.Timer`: Timer with the name you give correctly. + """ + return self._timers[name] + + def reset(self, name=None): + """Reset timers. + + Args: + name (str, optional): If name is designated, the named timer will be reset + and others will not, defaults to None. + """ + if self._on: + if name is not None: + self._timers[name].reset() + else: + for timer in self._timers: + timer.reset() + + def is_on(self): + return self._on + + def set_status(self, mode: bool): + self._on = mode + + def __iter__(self) -> Tuple[str, Timer]: + for name, timer in self._timers.items(): + yield name, timer diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3465079e4fbbf4dbd3460cddea68d4b818216957 --- /dev/null +++ b/colossalai/zero/__init__.py @@ -0,0 +1,16 @@ +from .gemini import ( + ColoInitContext, + GeminiAdamOptimizer, + GeminiDDP, + ZeroDDP, + ZeroOptimizer, + get_static_torch_model, + post_process_colo_init_ctx, +) +from .low_level import LowLevelZeroOptimizer +from .wrapper import zero_model_wrapper, zero_optim_wrapper + +__all__ = [ + 'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper', + 'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model' +] diff --git a/colossalai/zero/gemini/__init__.py b/colossalai/zero/gemini/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..60f85ca2f540497fb9ba8e11c55b1239a8cb45d6 --- /dev/null +++ b/colossalai/zero/gemini/__init__.py @@ -0,0 +1,11 @@ +from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration +from .colo_init_context import ColoInitContext, post_process_colo_init_ctx +from .gemini_ddp import GeminiDDP, ZeroDDP +from .gemini_mgr import GeminiManager +from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer +from .utils import get_static_torch_model + +__all__ = [ + 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP', + 'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx' +] diff --git a/colossalai/zero/gemini/chunk/__init__.py b/colossalai/zero/gemini/chunk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6914d2dbef4581dbf37610cfc7589a2c5be77406 --- /dev/null +++ b/colossalai/zero/gemini/chunk/__init__.py @@ -0,0 +1,6 @@ +from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState +from .manager import ChunkManager +from .search_utils import classify_params_by_dp_degree, search_chunk_configuration +from .utils import init_chunk_manager + +__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager'] diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..a7682eaf62e97c618d97897e27a88b02378a33b1 --- /dev/null +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -0,0 +1,576 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional + +import torch +import torch.distributed as dist + +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.utils import get_current_device + + +class TensorState(Enum): + FREE = 0 + COMPUTE = 1 + HOLD = 2 + HOLD_AFTER_BWD = 3 + READY_FOR_REDUCE = 4 + + +STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), + (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), (TensorState.COMPUTE, + TensorState.HOLD), + (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), + (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, + TensorState.HOLD)) + + +@dataclass +class TensorInfo: + state: TensorState + offset: int + end: int + + +class ChunkFullError(Exception): + pass + + +def is_storage_empty(tensor: torch.Tensor) -> bool: + return tensor.storage().size() == 0 + + +def free_storage(tensor: torch.Tensor) -> None: + if not is_storage_empty(tensor): + tensor.storage().resize_(0) + + +def alloc_storage(tensor: torch.Tensor) -> None: + if is_storage_empty(tensor): + tensor.storage().resize_(tensor.numel()) + + +class Chunk: + _total_number = 0 + + def __init__(self, + chunk_size: int, + process_group: ColoProcessGroup, + dtype: torch.dtype, + init_device: Optional[torch.device] = None, + cpu_shard_init: bool = False, + keep_gathered: bool = False, + pin_memory: bool = False) -> None: + """ + Chunk: A container owning a piece of contiguous memory space for tensors + Here we use all-gather operation to gather the whole chunk. + Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters. + It is designed to make the full use of communication and PCIE bandwidth. + + Args: + chunk_size (int): the number of elements in the chunk + process_group (ColoProcessGroup): the process group of this chunk + dtype (torch.dtype): the data type of the chunk + init_device (torch.device): optional, During the chunk construction process, where the tensor is stored. + The default value is None, which is the current GPU + cpu_shard_init (bool): a flag indicates the local chunk shard is resident on CPU. + keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory + pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory + """ + self.count_id = Chunk._total_number + Chunk._total_number += 1 + + self.chunk_size = chunk_size + self.utilized_size = 0 + + self.torch_pg = process_group.dp_process_group() + self.pg_size = dist.get_world_size(self.torch_pg) + self.pg_rank = dist.get_rank(self.torch_pg) + + # the chunk size should be divisible by the dp degree + if not keep_gathered: + assert chunk_size % self.pg_size == 0 + self.shard_size = chunk_size // self.pg_size + self.shard_begin = self.shard_size * self.pg_rank + self.shard_end = self.shard_begin + self.shard_size + self.valid_end = self.shard_size + + self.dtype = dtype + device = init_device or get_current_device() + + # chunk_temp is a global chunk, which only exists during building the chunks. + self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero + + self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA + + # cuda local chunk, which is sharded on GPUs + self.cuda_shard = None + # cpu local chunk, which is sharded on CPUs + self.cpu_shard = None + # is the chunks gathers, which means chunks are duplicated on each process, + # and we should use the cuda_global_chunk. + self.is_gathered = True + + # configure the init device of the shard + # no-offload default: fp16, fp32 -> CUDA + # offload default: fp16, fp32 -> CPU + self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device() + + self.chunk_mem = self.chunk_size * self.chunk_temp.element_size() + self.shard_mem = self.chunk_mem // self.pg_size + + # each tensor is associated with a TensorInfo to track its meta info + # (state, offset, end) + self.tensors_info: Dict[torch.Tensor, TensorInfo] = {} + # the total number of tensors in the chunk + self.num_tensors = 0 + + # Record the number of tensors in different states + self.tensor_state_cnter: Dict[TensorState, int] = dict() + for state in TensorState: + self.tensor_state_cnter[state] = 0 + + # If a chunk is kept gathered, + # they are treated the same as that of the parameters in DDP during training. + self.keep_gathered = keep_gathered + if self.keep_gathered: + pin_memory = False # since this chunk is gathered, it doesn't need to pin + + # if pin_memory is True, we allocate a piece of CPU pin-memory + # for it all the time + self.pin_memory = pin_memory + + # we introduce the paired chunk here + # it refers to another chunk having the same parameters + # but with different dtype(such as fp16_chunk.paired_chunk -> fp32_chunk + self.paired_chunk = None + # if this chunk is synchronized with the optimizer, the flag is True + self.optim_sync_flag = True + # if the cpu_shard has been visited during the training step, the flag is True + self.cpu_vis_flag = False + + # whether to record l2 norm for the gradient clipping calculation + self.l2_norm_flag = False + self.l2_norm = None + + @property + def memory_usage(self) -> Dict[str, int]: + cuda_memory = 0 + cpu_memory = 0 + + if self.chunk_temp is not None: + # this chunk is not closed + if self.chunk_temp.device.type == 'cuda': + cuda_memory += self.chunk_mem + else: + cpu_memory += self.chunk_mem + else: + if self.is_gathered: + cuda_memory += self.chunk_mem + if self.cuda_shard is not None: + cuda_memory += self.shard_mem + if self.cpu_shard is not None: + cpu_memory += self.shard_mem + + return dict(cuda=cuda_memory, cpu=cpu_memory) + + @property + def device_type(self) -> str: + if self.chunk_temp is not None: + return self.chunk_temp.device.type + else: + if self.is_gathered: + return 'cuda' + elif self.cuda_shard is not None: + return 'cuda' + else: + return 'cpu' + + @property + def payload(self) -> torch.Tensor: + # sanity check + assert self.chunk_temp is None + + if self.is_gathered: + return self.cuda_global_chunk + elif self.cuda_shard is not None: + return self.cuda_shard + else: + return self.cpu_shard + + @property + def payload_mem(self) -> int: + # sanity check + assert self.chunk_temp is None + + if self.is_gathered: + return self.chunk_mem + else: + return self.shard_mem + + @property + def can_move(self) -> bool: + return not self.is_gathered + + @property + def can_release(self) -> bool: + if self.keep_gathered: + return False + else: + return self.tensor_state_cnter[TensorState.HOLD] + \ + self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors + + @property + def can_reduce(self): + return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors + + @property + def has_inf_or_nan(self) -> bool: + """Check if the chunk has inf or nan values on CUDA. + """ + if self.is_gathered: + valid_tensor = self.cuda_global_chunk[:self.utilized_size] + else: + assert self.cuda_shard is not None # only check on CUDA + valid_tensor = self.cuda_shard[:self.valid_end] + + return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() + + def set_l2_norm(self) -> None: + """Record l2 norm of this chunks on CUDA. + """ + assert self.l2_norm is None, "you are calculating the l2 norm twice" + if self.is_gathered: + valid_tensor = self.cuda_global_chunk[:self.utilized_size] + else: + assert self.cuda_shard is not None # calculate on CUDA + valid_tensor = self.cuda_shard[:self.valid_end] + chunk_l2_norm = valid_tensor.data.float().norm(2) + self.l2_norm = chunk_l2_norm.item()**2 + + def append_tensor(self, tensor: torch.Tensor): + """Add a tensor to the chunk. + + Args: + tensor (torch.Tensor): a tensor to be added to the chunk + """ + # sanity check + assert self.chunk_temp is not None + assert tensor.dtype == self.dtype + + new_utilized_size = self.utilized_size + tensor.numel() + # raise exception when the chunk size is exceeded + if new_utilized_size > self.chunk_size: + raise ChunkFullError + + self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten()) + assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor" + tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape) + + # record all the information about the tensor + self.num_tensors += 1 + tensor_state = TensorState.HOLD + self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size) + self.tensor_state_cnter[tensor_state] += 1 + self.utilized_size = new_utilized_size + + def close_chunk(self): + """Close the chunk. Any tensor can't be appended to a closed chunk later. + """ + # sanity check + assert self.chunk_temp is not None + + # calculate the valid end for each shard + if self.utilized_size <= self.shard_begin: + self.valid_end = 0 + elif self.utilized_size < self.shard_end: + self.valid_end = self.utilized_size - self.shard_begin + + if self.chunk_temp.device.type == 'cpu': + self.cuda_global_chunk = self.chunk_temp.to(get_current_device()) + self.__update_tensors_ptr() + else: + self.cuda_global_chunk = self.chunk_temp + self.chunk_temp = None + + self.__scatter() + # gathered chunk never have shard attribute + if self.keep_gathered: + return + + if self.pin_memory or self.shard_device.type == 'cpu': + self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory) + self.cpu_shard.copy_(self.cuda_shard) + self.cpu_vis_flag = True # cpu_shard has been visited + + if self.shard_device.type == 'cpu': + self.cuda_shard = None + + def shard_move(self, device: torch.device, force_copy: bool = False): + """Move the shard tensor in the chunk. + + Args: + device: the device to which the shard will move + force_copy: if True, copy function is called mandatorily + """ + # sanity check + assert not self.is_gathered + # when the current chunk is not synchronized with the optimizer + # just use another way for the movement + if not self.optim_sync_flag: + assert device.type == 'cuda', "each chunk should first be moved to CUDA" + self.__paired_shard_move() + self.optim_sync_flag = True + return + + if device.type == 'cuda': + assert device == get_current_device(), "can't move chunk to another device" + + if self.cuda_shard: + return + + self.cuda_shard = self.cpu_shard.to(get_current_device()) + + if not self.pin_memory: + self.cpu_shard = None + elif device.type == 'cpu': + if self.cuda_shard is None: + return + + if self.pin_memory: + if force_copy or not self.cpu_vis_flag: + self.cpu_shard.copy_(self.cuda_shard) + # if cpu_shard has been visited + # copy operation is not need + else: + self.cpu_shard = self.cuda_shard.cpu() + self.cpu_vis_flag = True + self.cuda_shard = None + else: + raise NotImplementedError + + def access_chunk(self): + """Make the chunk usable for the parameters inside it. It's an operation done in CUDA. + """ + # sanity check + assert self.chunk_temp is None + + if not self.is_gathered: + self.__gather() + self.__update_tensors_ptr() + + def release_chunk(self): + """Release the usable chunk. It's an operation done in CUDA. + """ + # sanity check + assert self.chunk_temp is None + + if self.is_gathered: + self.__scatter() + + def reduce(self): + """Reduce scatter all the gradients. It's an operation done in CUDA. + """ + # sanity check + assert self.is_gathered + + if self.pg_size == 1: + # tricky code here + # just move cuda_global_chunk to cuda_shard + # the communication is not necessary + self.__scatter() + elif self.keep_gathered: + # we use all-reduce here + dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg) + else: + self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device()) + + input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0)) + dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg) + + free_storage(self.cuda_global_chunk) + self.is_gathered = False + self.__update_tensors_state(TensorState.HOLD) + + def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: + """ + Make a transition of the tensor into the next state. + + Args: + tensor (torch.Tensor): a torch Tensor object. + tensor_state (TensorState): the target state for transition. + """ + + # As the gradient hook can be triggered either before or after post-backward + # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce + # or compute -> ready_for_reduce -> hold_after_bwd + # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd + # this function only apply valid state transformation + # invalid calls will be ignored and nothing changes + if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS: + return + self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state) + + def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: + """ + Copy data slice to the memory space indexed by the input tensor in the chunk. + + Args: + tensor (torch.Tensor): the tensor used to retrive meta information + data_slice (torch.Tensor): the tensor to be copied to the chunk + """ + # sanity check + assert self.is_gathered + + tensor_info = self.tensors_info[tensor] + self.cuda_global_chunk[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten()) + tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape) + + def get_valid_length(self) -> int: + """Get the valid length of the chunk's payload. + """ + if self.keep_gathered: + return self.utilized_size + else: + return self.valid_end + + def init_pair(self, friend_chunk: 'Chunk') -> None: + """Initialize the paired chunk. + """ + if self.paired_chunk is None and friend_chunk.paired_chunk is None: + self.paired_chunk = friend_chunk + friend_chunk.paired_chunk = self + else: + assert self.paired_chunk is friend_chunk + assert friend_chunk.paired_chunk is self + + def optim_update(self) -> None: + """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer. + """ + # sanity check + assert self.paired_chunk is not None + + friend_chunk = self.paired_chunk + if self.is_gathered is True: + assert friend_chunk.is_gathered is True + self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk) + self.optim_sync_flag = True + elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda': + self.cuda_shard.copy_(friend_chunk.cuda_shard) + self.optim_sync_flag = True + self.cpu_vis_flag = False + else: + # optim_sync_flag is set to False + # see shard_move function for more details + assert friend_chunk.device_type == 'cpu' + assert self.device_type == 'cpu' + self.optim_sync_flag = False + self.cpu_vis_flag = False + + def get_tensors(self) -> List[torch.Tensor]: + return list(self.tensors_info.keys()) + + def __gather(self): + if not self.is_gathered: + # sanity check + assert self.cuda_shard is not None + + alloc_storage(self.cuda_global_chunk) + gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0)) + dist.all_gather(gather_list, self.cuda_shard, self.torch_pg) + + self.cuda_shard = None + self.is_gathered = True + + def __scatter(self): + if self.keep_gathered: + return + + if self.is_gathered: + # sanity check + assert self.cuda_shard is None + + self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device) + + self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin:self.shard_end]) + + free_storage(self.cuda_global_chunk) + self.is_gathered = False + + def __paired_shard_move(self): + assert self.paired_chunk is not None, "chunks should be paired before training" + optim_chunk = self.paired_chunk + assert self.chunk_size == optim_chunk.chunk_size + + # only be called when optimizer state is in CPU memory + # the grad and param should be in the same device + assert self.cuda_shard is None + temp = optim_chunk.cpu_shard.to(get_current_device()) + # avoid to transform FP32 in CPU + self.cuda_shard = temp.to(self.dtype) + + if not self.pin_memory: + self.cpu_shard = None + + def __update_tensors_ptr(self) -> None: + # sanity check + assert self.is_gathered + assert type(self.cuda_global_chunk) == torch.Tensor + + for tensor, tensor_info in self.tensors_info.items(): + tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape) + + def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState): + self.tensor_state_cnter[tensor_info.state] -= 1 + tensor_info.state = next_state + self.tensor_state_cnter[tensor_info.state] += 1 + + def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None): + for tensor_info in self.tensors_info.values(): + if prev_state is None or tensor_info.state == prev_state: + self.__update_one_tensor_info(tensor_info, next_state) + + def __hash__(self) -> int: + return hash(id(self)) + + def __eq__(self, __o: object) -> bool: + return self is __o + + def __repr__(self, detailed: bool = True): + output = [ + "Chunk Information:\n", + "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype, + self.pg_size), + "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format( + self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size) + ] + + def print_tensor(tensor, prefix=''): + output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, + tensor.device)) + + if self.chunk_temp is not None: + output.append("\tchunk temp:\n") + print_tensor(tensor=self.chunk_temp, prefix='\t\t') + + if self.cuda_global_chunk is not None and self.cuda_global_chunk.storage().size() > 0: + output.append("\tchunk total:\n") + print_tensor(tensor=self.cuda_global_chunk, prefix='\t\t') + + if self.cuda_shard is not None: + output.append("\tcuda shard:\n") + print_tensor(tensor=self.cuda_shard, prefix='\t\t') + + if self.cpu_shard is not None: + output.append("\tcpu shard:\n") + print_tensor(tensor=self.cpu_shard, prefix='\t\t') + + memory_info = self.memory_usage + output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu'])) + + if detailed: + output.append("\ttensor state monitor:\n") + for st in TensorState: + output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st])) + + return ''.join(output) diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..d85df0b00476f88c4bf29943d88c9c9ea99ebbe8 --- /dev/null +++ b/colossalai/zero/gemini/chunk/manager.py @@ -0,0 +1,251 @@ +from collections import deque +from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple + +import torch + +from colossalai.tensor import ColoTensor +from colossalai.utils import get_current_device + +from .chunk import Chunk, ChunkFullError, TensorState + + +class ChunkManager: + """ + A manager class to manipulate the tensors in chunks. + + Args: + chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager. + init_device (torch.device): optional, the device on which the chunk is initialized. The default is None. + """ + + def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None: + + self.device = init_device or get_current_device() + self.dp_degree_chunk_size_dict: Dict[int, int] = dict() + self.kwargs_config = chunk_configuration + for k, v in self.kwargs_config.items(): + self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size') + v['init_device'] = self.device + + self.chunk_groups: Dict[str, Deque] = dict() + self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict() + self.accessed_chunks: Set[Chunk] = set() + self.accessed_mem: int = 0 + self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} + + def register_tensor(self, + tensor: ColoTensor, + group_type: str, + config_key: int, + cpu_offload: bool = False, + pin_memory: bool = False) -> None: + """ + Register a tensor to the chunk manager. + Then, the tensor should be accessed by `get_chunks`. + + Args: + tensor: the tensor appended to the chunk + group_type: the data type of the group. + config_key: the key of the group's name, the size of the dp world + cpu_offload: if True, the chunk will be closed on CPU + pin_memory: whether the chunk is pinned in the cpu memory + """ + assert tensor not in self.tensor_chunk_map + assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager" + assert config_key in self.dp_degree_chunk_size_dict + + chunk_size = self.dp_degree_chunk_size_dict[config_key] + chunk_kwargs = self.kwargs_config[config_key] + group_name = "{}_{}".format(group_type, config_key) + chunk_group = self.__get_chunk_group(group_name) + + try: + # append the tensor to the last chunk + chunk_group[-1].append_tensor(tensor) + except (IndexError, ChunkFullError): + # the except statement will be triggered when there is no chunk or + # the last chunk in the chunk group is full + # this will create a new chunk and allocate this chunk to its corresponding process + if chunk_group: + # the chunk group is not empty + # close the last chunk + self.__close_one_chunk(chunk_group[-1]) + + if tensor.numel() > chunk_size: + chunk_size = tensor.numel() + dp_size = tensor.get_dp_world_size() + chunk_size = chunk_size + (-chunk_size % dp_size) + + chunk = Chunk( + chunk_size=chunk_size, + process_group=tensor.process_group, + dtype=tensor.dtype, + cpu_shard_init=cpu_offload, + pin_memory=pin_memory, + **chunk_kwargs, + ) + + chunk_group.append(chunk) + chunk.append_tensor(tensor) + self.__add_memory_usage(chunk.memory_usage) + + self.tensor_chunk_map[tensor] = chunk_group[-1] + + def close_all_groups(self): + """Close all the chunks of all groups. + """ + for group_name in self.chunk_groups: + self.__close_one_chunk(self.chunk_groups[group_name][-1]) + + def access_chunk(self, chunk: Chunk) -> None: + """Make the chunk can be used for calculation. + """ + if chunk in self.accessed_chunks: + return + self.__sub_memroy_usage(chunk.memory_usage) + if chunk.device_type == 'cpu': + chunk.shard_move(get_current_device()) + self.__add_accessed_chunk(chunk) + self.__add_memory_usage(chunk.memory_usage) + + def release_chunk(self, chunk: Chunk) -> None: + """Scatter the chunk in CUDA. + """ + if chunk not in self.accessed_chunks: + return + if chunk.can_release: + self.__sub_memroy_usage(chunk.memory_usage) + self.__sub_accessed_chunk(chunk) + self.__add_memory_usage(chunk.memory_usage) + + def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None: + """Move the shard of the chunk to the target device. + """ + if not chunk.can_move or chunk.device_type == device.type: + return + self.__sub_memroy_usage(chunk.memory_usage) + chunk.shard_move(device, force_copy) + self.__add_memory_usage(chunk.memory_usage) + + def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: + """Transit tensor state according to pre-defined state machine. + """ + chunk = self.tensor_chunk_map[tensor] + chunk.tensor_trans_state(tensor, state) + + def reduce_chunk(self, chunk: Chunk) -> bool: + """Reduce or all reduce the chunk. + """ + if not chunk.can_reduce: + return False + self.__sub_memroy_usage(chunk.memory_usage) + chunk.reduce() + self.__sub_accessed_chunk(chunk) + self.__add_memory_usage(chunk.memory_usage) + return True + + def fake_release_chunk(self, chunk: Chunk) -> None: + """Release gathered chunk in a fake mode. + This function is used for keep-gathered chunk in the inference mode. + """ + assert chunk.keep_gathered + assert chunk.tensor_state_cnter[TensorState.HOLD] == chunk.num_tensors + self.__sub_accessed_chunk(chunk) + + def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None: + """ + Copy data to the chunk. + + Args: + tensor (torch.Tensor): the tensor used to retrive meta information + data (torch.Tensor): the tensor to be copied to the chunk + """ + chunk = self.tensor_chunk_map[tensor] + chunk.copy_tensor_to_chunk_slice(tensor, data) + + def get_chunk(self, tensor: torch.Tensor) -> Chunk: + """ + Return the chunk owning the tensor. + + Args: + tensor (torch.Tensor): a torch tensor object + """ + return self.tensor_chunk_map[tensor] + + def get_cuda_movable_chunks(self) -> List[Chunk]: + """ + Get all chunks that can be moved. + """ + chunk_list = [] + for chunk in self.accessed_chunks: + if chunk.can_release: + chunk_list.append(chunk) + chunk_list.sort(key=lambda x: x.count_id) + return chunk_list + + def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]: + """ + Get all chunks owning the input tensors. + + Args: + tensors (Iterable[torch.Tensor]): the tensors used to look for chunks + """ + chunks = [] + for tensor in tensors: + chunk = self.get_chunk(tensor) + if chunk not in chunks: + chunks.append(chunk) + return tuple(chunks) + + def add_extern_static_tensor(self, tensor: torch.Tensor) -> None: + """Add extern static tensor to chunk manager. + Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them. + They are "static", which means their shape, dtype, device never change. + Thus, their memory usage never changes. + + Args: + tensor (torch.Tensor): An extern static tensor. E.g. optimizer state. + """ + assert tensor not in self.tensor_chunk_map + self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size() + + def __repr__(self) -> str: + msg = [ + 'Chunk Manager Information:\n', + 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' + ] + for group_name, group in self.chunk_groups.items(): + msg.append(f'Group {group_name}:\n') + for i, chunk in enumerate(group): + msg.append(f'[{i}] {chunk}\n') + return ''.join(msg) + + def __get_chunk_group(self, group_name: str) -> Deque: + """Register a chunk group. + """ + if group_name not in self.chunk_groups: + self.chunk_groups[group_name] = deque() + return self.chunk_groups[group_name] + + def __close_one_chunk(self, chunk: Chunk): + self.__sub_memroy_usage(chunk.memory_usage) + chunk.close_chunk() + self.__add_memory_usage(chunk.memory_usage) + + def __sub_memroy_usage(self, usage: Dict[str, int]): + for k, v in usage.items(): + self.total_mem[k] -= v + + def __add_memory_usage(self, usage: Dict[str, int]): + for k, v in usage.items(): + self.total_mem[k] += v + + def __add_accessed_chunk(self, chunk: Chunk): + chunk.access_chunk() + self.accessed_chunks.add(chunk) + self.accessed_mem += chunk.chunk_mem + + def __sub_accessed_chunk(self, chunk: Chunk): + chunk.release_chunk() + self.accessed_chunks.remove(chunk) + self.accessed_mem -= chunk.chunk_mem diff --git a/colossalai/zero/gemini/chunk/search_utils.py b/colossalai/zero/gemini/chunk/search_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..da58e038c8792bdc616250d0a4e512817438c34d --- /dev/null +++ b/colossalai/zero/gemini/chunk/search_utils.py @@ -0,0 +1,196 @@ +import math +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch.distributed as dist +import torch.nn as nn + +from colossalai.tensor import ColoParameter +from colossalai.utils import is_ddp_ignored +from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator + + +def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None: + """_filter_exlarge_params + + Filter those parameters whose size is too large (more than 3x standard deviations) from others. + + Args: + model (nn.Module): the model. + size_dict (Dict[int, List[int]]): the size dict of parameters. + """ + agg_size_list = [] + for key in size_dict: + agg_size_list.extend(size_dict[key]) + + if len(agg_size_list) == 0: + return + + params_size_arr = np.array(agg_size_list) + + std = np.std(params_size_arr) + mean = np.mean(params_size_arr) + upper_limit = mean + 3 * std + + for key in size_dict: + org_list = size_dict[key] + size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list)) + + +def _get_unused_byte(size_list: List[int], chunk_size: int) -> int: + """_get_unused_byte + + Get unused byte for a certain chunk size. + + Args: + size_list (List[int]): the size list of parameters. + chunk_size (int): the chunk size. + + Returns: + int: the unused byte. + """ + acc = 0 + left = 0 + for s in size_list: + if s > left: + acc += left + left = chunk_size + left -= s + return left + acc + + +def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int: + """_tensor_numel + + Get the number of elements of a tensor. + + Args: + local_param (ColoParameter): The local parameter. + strict_ddp_flag (bool): whether to enable the strict ddp mode. + + Returns: + int: the number of elements. + """ + if strict_ddp_flag and type(local_param) is ColoParameter: + return local_param.numel_global() + else: + # if local_param is not ColoParameter, we assume it's replicated + return local_param.numel() + + +def classify_params_by_dp_degree(param_order: OrderedParamGenerator, + strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]: + """classify_params_by_dp_degree + + Classify the parameters by their dp degree + + Args: + param_order (OrderedParamGenerator): the order of param be visied + strict_ddp_flag (bool, optional): whether to enable the strict ddp mode. Defaults to False. + + Returns: + Dict[int, List[ColoParameter]]: a dict contains the classification results. + The keys are dp_degrees and the values are parameters. + """ + params_dict: Dict[int, List[ColoParameter]] = dict() + for param in param_order.generate(): + # assert isinstance(param, ColoParameter), "please init model in the ColoInitContext" + if is_ddp_ignored(param): + continue + + if strict_ddp_flag or type(param) is not ColoParameter: + # if model is not initialized with ColoInitContext, we assume it's replicated + # TODO(ver217): integrate DTensor + param_key = dist.get_world_size() + else: + param_key = param.process_group.dp_world_size() + + if param_key not in params_dict: + params_dict[param_key] = [] + params_dict[param_key].append(param) + + return params_dict + + +def search_chunk_configuration( + model: nn.Module, + search_range_mb: float, + search_interval_byte: int, # hidden size is the best value for the interval + min_chunk_size_mb: float = 32, + filter_exlarge_params: bool = True, + strict_ddp_flag: bool = False, + memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]: + """search_chunk_configuration + + Search the chunk configuration for a model. + + Args: + model (nn.Module): torch module + search_range_mb (float): searching range in mega byte. + search_interval_byte (int): searching interval in byte. + min_chunk_size_mb (float, optional): the minimum size of a distributed chunk. + filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True. + strict_ddp_flag (bool, optional): whether to enable the strict ddp mode. + all parameters keep replicated in this mode. + + Returns: + Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte. + """ + + if memstas is not None: + param_order = memstas.param_order() + else: + # build the param visited order right now + param_order = OrderedParamGenerator() + for p in model.parameters(): + param_order.append(p) + + search_range_byte = round(search_range_mb * 1024**2) + min_chunk_size_byte = round(min_chunk_size_mb * 1024**2) + assert search_range_byte >= 0 + + params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag) + size_lcm = np.lcm.reduce(list(params_dict.keys())) + config_dict: Dict[int, Dict] = dict() + total_param_size = 0 + + size_dict: Dict[int, List[int]] = dict() + for dp_degree in params_dict: + params_list = params_dict[dp_degree] + size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list] + group_acc_size = sum(size_list) + total_param_size += group_acc_size + + # let small parameters keep gathered in CUDA all the time + if group_acc_size < min_chunk_size_byte: + config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True) + else: + size_dict[dp_degree] = size_list + + if filter_exlarge_params: + _filter_exlarge_params(model, size_dict) + + max_size = min_chunk_size_byte + for key in size_dict: + max_size = max(max_size, max(size_dict[key])) + start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte) + + min_chunk_waste = float('+inf') + best_chunk_size = start_size + + for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): + temp_waste = 0 + for key in size_dict: + temp_waste += _get_unused_byte(size_dict[key], chunk_size) + if temp_waste < min_chunk_waste: + min_chunk_waste = temp_waste + best_chunk_size = chunk_size + + # the chunk size needs to be divided by each groups sizes + best_chunk_size = best_chunk_size + (-best_chunk_size % size_lcm) + for dp_degree in params_dict: + if dp_degree in config_dict: + continue + config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False) + + return config_dict, total_param_size, min_chunk_waste diff --git a/colossalai/zero/gemini/chunk/utils.py b/colossalai/zero/gemini/chunk/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..71242dcd6d498e537155abbb3d1f882d2b71da66 --- /dev/null +++ b/colossalai/zero/gemini/chunk/utils.py @@ -0,0 +1,52 @@ +from time import time +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.utils import is_ddp_ignored + +from .manager import ChunkManager +from .search_utils import search_chunk_configuration + + +def safe_div(a, b): + if a == 0: + return 0 + return a / b + + +def init_chunk_manager(model: nn.Module, + init_device: Optional[torch.device] = None, + hidden_dim: Optional[int] = None, + verbose: bool = False, + **kwargs) -> ChunkManager: + if hidden_dim: + search_interval_byte = hidden_dim + else: + search_interval_byte = 1024 # defaults to 1kb + kwargs["search_interval_byte"] = search_interval_byte + + dist.barrier() + begin = time() + + config_dict, total_size, wasted_size = search_chunk_configuration(model, **kwargs) + + dist.barrier() + end = time() + span_s = end - begin + mb_size = 1024**2 + total_size /= mb_size + wasted_size /= mb_size + + if verbose and dist.get_rank() == 0: + print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s), + "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size), + "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)), + sep='', + flush=True) + dist.barrier() + + chunk_manager = ChunkManager(config_dict, init_device) + return chunk_manager diff --git a/colossalai/zero/gemini/colo_init_context.py b/colossalai/zero/gemini/colo_init_context.py new file mode 100644 index 0000000000000000000000000000000000000000..75f8576ca477977e03e054eb195c49f5e0048c5f --- /dev/null +++ b/colossalai/zero/gemini/colo_init_context.py @@ -0,0 +1,191 @@ +from typing import Any, Dict, Iterator, Optional, Tuple, Union + +import torch +from torch import nn + +from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup +from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses + +# find named_params includes replica + + +def _named_params_with_replica( + module: nn.Module, + prefix: str = '', + recurse: bool = True, +) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]: + modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)] + + for mod_prefix, mod in modules: + for name, val in mod._parameters.items(): + if val is None: + continue + name = mod_prefix + ('.' if mod_prefix else '') + name + yield name, val + + +def _convert_to_coloparam(param: torch.nn.Parameter, + device: torch.device, + dtype=torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec: Optional[Any] = None) -> ColoParameter: + + if type(param) is ColoParameter: + return param + # detaching tensor is necessary for optimizers. + requires_grad = param.requires_grad + # param is the global tensor. + + if param.device.type == "meta": + colo_param = ColoParameter(param, requires_grad=requires_grad) + else: + colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad) + + # if default_shard_plan exists, shard the param during initialization. + # This can reduce the model size after initialization. + # NOTE() embedding usually can not be correctly sharded. So I use except to handle + # the param that can not be sharded by the default plan + if default_pg is not None: + colo_param.set_process_group(default_pg) + + if default_dist_spec is not None: + try: + colo_param.set_dist_spec(default_dist_spec) + except: + pass + return colo_param + + +def ColoModulize(module): + """ + Replacing the parameters() and named_parameters() with our customized ones + """ + + module._colo_visited = True + + +class ColoInitContext(InsertPostInitMethodToModuleSubClasses): + + def __init__(self, + device: torch.device = torch.device('cpu'), + dtype: torch.dtype = torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec=None): + """ + Args: + device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu'). + dtype (torch.dtype): the dtype of parameters initialized. Defaults to torch.float. + default_pg (ProcessGroup): the default process group for all initialized parameters. + default_dist_spec: the default distributed specifications. + """ + super().__init__() + self._device = device + self._dtype = dtype + + self._register_colo_modules() + self._default_pg = default_pg + self._default_dist_spec = default_dist_spec + + def _register_colo_modules(self): + from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module + register_colo_module(torch.nn.Linear, ColoLinear()) + register_colo_module(torch.nn.Embedding, ColoEmbedding()) + + def _pre_context_exec(self): + pass + + def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): + """ + The function to call at the end of the constructor of each module. + FIXME(fjr) The module may be passed to this function multiple times? + """ + name_list = [] + for name, param in _named_params_with_replica(module): + if type(param) is ColoParameter: + continue + + split = name.rfind('.') + if split >= 0: # param in submodule + module_name = name[:split] + param_name = name[split + 1:] + else: + module_name = '' # param in current module + param_name = name + name_list.append((module_name, param_name)) + + replaced_tensors = dict( + ) # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference + for module_name, param_name in name_list: + submodule = module.get_submodule(module_name) + param = submodule.get_parameter(param_name) + if param in replaced_tensors: + colo_param = replaced_tensors[param] + else: + colo_param = _convert_to_coloparam(param, self._device, self._dtype, self._default_pg, + self._default_dist_spec) + replaced_tensors[param] = colo_param + delattr(submodule, param_name) + setattr(submodule, param_name, colo_param) + colo_param.shared_param_modules.append(submodule) + + param_number = 0 + meta_param_number = 0 + buffer_number = 0 + meta_buffer_number = 0 + + for param in module.parameters(): + param_number += 1 + meta_param_number += (param.device.type == 'meta') + + for buffer in module.buffers(): + buffer_number += 1 + meta_buffer_number += (buffer.device.type == 'meta') + + if meta_param_number > 0 and meta_param_number != param_number: + raise ValueError("Meta parameters and valued parameters can not be in the same model") + if meta_buffer_number > 0 and meta_buffer_number != buffer_number: + raise ValueError("Meta buffers and valued buffers can not be in the same model") + + if meta_buffer_number == 0: + for buffer in module.buffers(): + buffer.data = buffer.data.to(device=self._device) + + +def post_process_colo_init_ctx(model: torch.nn.Module, + device: torch.device = torch.device('cpu'), + dtype: torch.dtype = torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec=None): + """post_process_colo_init_ctx + + This function is called after `ColoInitContext`. + + Args: + model (torch.nn.module): the model + device (torch.device, optional): device type of the model params. Defaults to torch.device('cpu'). + dtype (torch.dtype, optional): dtype of the model params. Defaults to torch.float. + default_pg (Optional[ProcessGroup], optional): default process group. Defaults to None. Indicates a DP-only process group. + default_dist_spec (Any, optional): default dist spec of params. Defaults to None. + + Raises: + RuntimeError: raise error if + """ + + torch_params = [] + for n, p in model.named_parameters(): + if not isinstance(p, ColoParameter): + # print(f"{n} is not a ColoParameter. We are going to converting it to ColoParameter") + torch_params.append((n, p)) + + for (n, param) in torch_params: + name_list = n.split('.') + module = model + for i in range(len(name_list) - 1): + module = module._modules[name_list[i]] + delattr(module, name_list[-1]) + setattr(module, name_list[-1], _convert_to_coloparam(param, device, dtype, default_pg, default_dist_spec)) + + del torch_params + for n, p in model.named_parameters(): + if not isinstance(p, ColoTensor): + raise RuntimeError diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..8a001b114e9a89fd65c45a5e2466b3d31627affc --- /dev/null +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -0,0 +1,747 @@ +import itertools +from collections import OrderedDict +from contextlib import nullcontext +from functools import partial +from typing import Dict, Iterator, List, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.checkpoint_io.utils import calculate_tensor_size +from colossalai.logging import get_dist_logger +from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.tensor import ReplicaSpec +from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.utils import get_current_device, is_ddp_ignored +from colossalai.utils.model.experimental import LazyTensor + +from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager +from .gemini_hook import GeminiZeROHook +from .gemini_mgr import GeminiManager +from .memory_tracer import MemStats, OrderedParamGenerator +from .utils import get_temp_total_chunk_on_cuda + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + +__all__ = [ + 'ZeroDDP', + 'GeminiDDP', +] + + +class ZeroDDP(ColoDDP): + """ZeRO DDP for ColoTensor. + Warning: Nested ZeroDDP is not supported now. + It is designed to be used with ChunkManager and GeminiManager. + For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``. + + Args: + module (torch.nn.Module): Module to apply ZeRO-DP. + gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous memory space. + For more details, see the API reference of ``GeminiManager``. + pin_memory (bool): Chunks on CPU Memory use pin-memory. + force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. + Defaults to False. + strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated. + Defaults to False. Users can set it to True, when they clearly know that they only need DDP. + scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference. + """ + + def __init__(self, + module: torch.nn.Module, + gemini_manager: GeminiManager, + pin_memory: bool = False, + force_outputs_fp32: bool = False, + strict_ddp_mode: bool = False, + scatter_after_inference: bool = True) -> None: + self.gemini_manager = gemini_manager + self.chunk_manager: ChunkManager = gemini_manager.chunk_manager + self.force_outputs_fp32 = force_outputs_fp32 + self.param_op_hook = GeminiZeROHook(gemini_manager) + self.fp32_params: List[ColoTensor] = list() + self.fp16_params: List[ColoParameter] = list() + self.overflow_counter = 0 + self.grads_device: Dict[torch.Tensor, torch.device] = dict() + self.param2name: Dict[nn.Parameter, str] = dict() + self.name2param: Dict[str, nn.Parameter] = dict() + self.scatter_after_inference = scatter_after_inference + + self._logger = get_dist_logger() + + if self.gemini_manager._premade_memstats_: + # build chunk in param runtime visited order. + param_order = self.gemini_manager.memstats()._param_runtime_order + else: + # build chunk in param initialized order. + # Note: in this way, it can not get filter unused params during runtime. + param_order = OrderedParamGenerator() + for p in module.parameters(): + param_order.append(p) + + self._init_chunks(param_order=param_order, + strict_ddp_mode=strict_ddp_mode, + cpu_offload=self.gemini_manager.policy_name != 'cuda', + pin_memory=pin_memory) + + for name, param in module.named_parameters(): + self.param2name[param] = name + for m_name, m_var in module.named_modules(): + for p_name, p_var in m_var.named_parameters(recurse=False): + param_name = m_name + '.' + p_name if m_name else p_name + self.name2param[param_name] = p_var + super().__init__(module, process_group=ColoProcessGroup()) + self._cast_buffers() + + def _post_forward(self): + """This function is only triggered for inference. + """ + access_list = list(self.chunk_manager.accessed_chunks) + # we need to scatter all accessed chunks and move them to their original places + for chunk in access_list: + if chunk.keep_gathered: + self.chunk_manager.fake_release_chunk(chunk) + else: + assert chunk.can_release + self.chunk_manager.release_chunk(chunk) + first_param = next(iter(chunk.tensors_info)) + self.chunk_manager.move_chunk(chunk, self.grads_device[first_param]) + assert self.chunk_manager.accessed_mem == 0 + + def forward(self, *args, **kwargs): + # check whether we are in a inference mode + grad_flag = torch.is_grad_enabled() + if not grad_flag: + assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup( + ), "You should run a completed iteration as your warmup iter" + + args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) + self.module.zero_grad(set_to_none=True) + if not grad_flag: + outputs = self._inference_forward(*args, **kwargs) + else: + self.gemini_manager.pre_iter(*args) + with ColoParamOpHookManager.use_hooks(self.param_op_hook): + outputs = self.module(*args, **kwargs) + + if self.force_outputs_fp32: + return _cast_float(outputs, torch.float) + return outputs + + def _inference_forward(self, *args, **kwargs): + """This function is only triggered for inference. + """ + fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook) + if not self.scatter_after_inference: + # gather all chunks + for chunk in self.chunk_manager.get_chunks(self.fp16_params): + self.chunk_manager.access_chunk(chunk) + fwd_ctx = nullcontext() + with fwd_ctx: + outputs = self.module(*args, **kwargs) + if self.scatter_after_inference: + # scatter chunks + self._post_forward() + # reset all recorded attributes + self.gemini_manager.reset_attributes() + return outputs + + def _setup_grads_ptr(self): + for p in self.module.parameters(): + if is_ddp_ignored(p): + continue + p.grad = None + + def _pre_backward(self): + # set a visit label for all parameters + # the label is used to check whether the parameter is correctly reduced + for param in self.param2name: + if not is_ddp_ignored(param): + setattr(param, "_gemini_reduced", False) + + def _post_backward(self): + if self.chunk_manager.accessed_mem != 0: + error_params = ["Reduction failed at followed parameters:"] + for param in self.param2name: + if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"): + error_params.append(self.param2name[param]) + error_str = "\n\t".join(error_params) + raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.", + "The most possible reason is that the model is not compatible with ZeroDDP.\n", + f"{error_str}") + self._setup_grads_ptr() + self._logger.debug( + f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}' + ) + self.gemini_manager.post_iter() + + def backward(self, loss: torch.Tensor): + self._pre_backward() + with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): + loss.backward() + self._post_backward() + + def backward_by_grad(self, tensor, grad): + with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): + torch.autograd.backward(tensor, grad) + self._post_backward() + + def grad_handle(self, p, grad): + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + with torch._C.DisableTorchFunction(): + chunk = self.chunk_manager.get_chunk(p) + if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD: + raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " + "Some unsupported torch function is operated upon this parameter.") + self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) + chunk.copy_tensor_to_chunk_slice(p, grad) + reduced = self.chunk_manager.reduce_chunk(chunk) + if reduced: + if chunk.is_gathered: + chunk.cuda_global_chunk.div_(chunk.pg_size) + else: + chunk.cuda_shard.div_(chunk.pg_size) + # check overflow elements + self.overflow_counter += chunk.has_inf_or_nan + # record l2 norm for gradient clipping + if chunk.l2_norm_flag: + chunk.set_l2_norm() + self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) + return empty_grad + + def zero_grad(self, set_to_none: bool = False) -> None: + self.module.zero_grad(set_to_none=True) + + def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: + for tensor in chunk.get_tensors(): + self.grads_device[tensor] = device + + def state_dict(self, + destination=None, + prefix='', + keep_vars=False, + only_rank_0: bool = True, + dtype: torch.dtype = torch.float16): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are included. + Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + + Warning: The non strict state dict would ignore the parameters if the tensors of the parameters + are shared with other parameters which have been included in the dictionary. + When you need to load the state dict, you should set the argument `strict` to False. + + Returns: + dict: + a dictionary containing a whole state of the module + """ + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) + self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0, dtype) + + for hook in self._state_dict_hooks.values(): + hook_result = hook(self, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch.dtype = torch.float16) -> Dict: + """ + get gathered chunk content. + + Args: + chunk (Chunk): a chunk + only_rank_0 (bool): whether to only save data on rank 0 + + Returns: + Dict: a dict whose key is param name and value is param with correct payload + """ + # save parameters + chunk_to_save_data = dict() + temp_chunk = get_temp_total_chunk_on_cuda(chunk) + if torch.is_floating_point(temp_chunk): + temp_chunk = temp_chunk.to(dtype) + for tensor, tensor_info in chunk.tensors_info.items(): + record_tensor = torch.empty([0]) + record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) + if record_flag: + record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() + + assert tensor not in chunk_to_save_data + chunk_to_save_data[tensor] = record_tensor + + del temp_chunk + return chunk_to_save_data + + def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool, + dtype: torch.dtype) -> Dict: + """ + get param content from chunks. + + Args: + param_list (_type_): a list of torch.nn.Parameters + only_rank_0 (_type_): _description_ + + Returns: + Dict: a dict whose key is param name and value is param with correct payload + """ + # save parameters + param_to_save_data = dict() + chunk_list = self.chunk_manager.get_chunks(param_list) + for chunk in chunk_list: + param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) + return param_to_save_data + + def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, dtype=torch.float16): + r"""Saves module state to `destination` dictionary, containing a state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." + + # get copies of fp32 parameters in CPU + # as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16 + param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0, dtype) + # get the mapping between copies and fp16 parameters + p_mapping = dict() + for p, fp32_p in zip(self.fp16_params, self.fp32_params): + name = self.param2name[p] + assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) + record_parameter = param_to_save_data[fp32_p] + p_mapping[p] = record_parameter + for name, param in self.name2param.items(): + if param is not None: + if is_ddp_ignored(param): + # deal with ddp ignored parameters + destination[prefix + name] = param if keep_vars else param.detach() + else: + destination[prefix + name] = p_mapping[param] + del p_mapping + del param_to_save_data + + # save all buffers + for name, buf in self.named_buffers(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + # save extra states + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + destination[extra_state_key] = self.get_extra_state() + + def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): + r"""Copies parameters and buffers from :attr:`state_dict` into + this module and its descendants. If :attr:`strict` is ``True``, then + the keys of :attr:`state_dict` must exactly match the keys returned + by this module's :meth:`~torch.nn.Module.state_dict` function. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + strict (bool, optional): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + Note: + If a parameter or buffer is registered as ``None`` and its corresponding key + exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a + ``RuntimeError``. + """ + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + error_msgs: List[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + # mypy isn't aware that "_metadata" exists in state_dict + state_dict._metadata = metadata # type: ignore[attr-defined] + + prefix = '' + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join( + '"{}"'.format(k) for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) + + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(error_msgs))) + return _IncompatibleKeys(missing_keys, unexpected_keys) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + def load(param_name, dest_tensor, copy_func): + state_key = prefix + param_name + if state_key in state_dict: + input_param = state_dict[state_key] + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + if input_param.shape != dest_tensor.shape: + # local shape should match the one in checkpoint + error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' + 'the shape in current model is {}.'.format(state_key, input_param.shape, + dest_tensor.shape)) + return + try: + with torch.no_grad(): + copy_func(input_param) + except Exception as ex: + error_msgs.append('While copying the parameter named "{}", ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}, ' + 'an exception occurred : {}.'.format(state_key, dest_tensor.size(), + input_param.size(), ex.args)) + elif strict: + missing_keys.append(state_key) + + def load_fp32_parameter(chunk_slice, data): + chunk_slice.copy_(data.flatten()) + + for name, param in self.named_parameters(): + if is_ddp_ignored(param): + # deal with ddp ignored parameters + load(name, param, param.copy_) + + fp32_to_name = dict() + for p, fp32_p in zip(self.fp16_params, self.fp32_params): + if p is not None: + name = self.param2name[p] + fp32_to_name[fp32_p] = name + + chunk_list = self.chunk_manager.get_chunks(self.fp32_params) + for chunk in chunk_list: + temp_chunk = get_temp_total_chunk_on_cuda(chunk) + + for tensor, tensor_info in chunk.tensors_info.items(): + parameter_name = fp32_to_name[tensor] + parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end] + load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) + + if chunk.is_gathered: + chunk.cuda_global_chunk.copy_(temp_chunk) + elif chunk.cuda_shard is not None: + chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + else: + chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + + del temp_chunk + + for chunk_32 in chunk_list: + chunk_16 = chunk_32.paired_chunk + assert chunk_16 is not None + chunk_16.optim_update() + + for name, buf in persistent_buffers.items(): + if buf is not None: + load(name, buf, buf.copy_) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", + torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix):] + if input_name not in local_state: + unexpected_keys.append(key) + + def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool): + ddp_pg = ColoProcessGroup() + for p in param_order.generate(): + self._preprocess_param(p) + assert type(p) is ColoParameter + + # gather sharded parameters in the strict ddp mode + if strict_ddp_mode: + if not p.is_replicate(): + p.set_dist_spec(ReplicaSpec()) + p.set_process_group(pg=ddp_pg) + + # ignore the parameters with no gradient + if not p.requires_grad: + self.set_params_to_ignore([p]) + + # move ignored parameters to CUDA + if is_ddp_ignored(p): + p.data = p.data.to(device=get_current_device(), dtype=torch.float16) + continue + + # create a fp32 parameter + fp32_data = p.data.float() + fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) + # create a fp16 parameter + p.data = p.data.half() + + # register the fp16 parameter and fp32 parameter in the chunk manager + dp_world_size = p.process_group.dp_world_size() + self.chunk_manager.register_tensor(tensor=p, + group_type='fp16_param', + config_key=dp_world_size, + cpu_offload=cpu_offload, + pin_memory=pin_memory) + self.chunk_manager.register_tensor(tensor=fp32_p, + group_type='fp32_param', + config_key=dp_world_size, + cpu_offload=cpu_offload, + pin_memory=pin_memory) + + self.fp16_params.append(p) + self.fp32_params.append(fp32_p) + self.grads_device[p] = self.gemini_manager.default_device + + self.chunk_manager.close_all_groups() + + for p, fp32_p in zip(self.fp16_params, self.fp32_params): + chunk_16 = self.chunk_manager.get_chunk(p) + chunk_32 = self.chunk_manager.get_chunk(fp32_p) + chunk_32.init_pair(chunk_16) + + # keep gathered chunks are in CUDA + if chunk_16.keep_gathered: + self.grads_device[p] = get_current_device() + + def _cast_buffers(self): + for buffer in self.module.buffers(): + if isinstance(buffer, LazyTensor): + buffer.materialize() + buffer.data = buffer.cuda() + if torch.is_floating_point(buffer): + buffer.data = buffer.half() + + def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None: + """Convert parameter to ColoParameter in-place. + Args: + p (Union[nn.Parameter, ColoParameter, LazyTensor]): parameter to be converted + """ + if type(p) is ColoParameter: + # model is initialized with ColoInitContext + return + requires_grad = p.requires_grad + if isinstance(p, LazyTensor): + # model is initialized with LazyInitContext + p.materialize() + p.__class__ = ColoParameter + p.__init__(p, requires_grad=requires_grad) + + def state_dict_shard(self, + prefix: str = '', + keep_vars: bool = False, + max_shard_size: int = 1024, + only_rank_0: bool = True, + dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]: + """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. + + Both parameters and persistent buffers (e.g. running averages) are included. + Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + + Args: + prefix (str, optional): the prefix for parameters and buffers used in this + module. Defaults to ''. + keep_vars (bool, optional): whether to keep variables. Defaults to False. + max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024. + only_rank_0 (bool, optional): only get data on rank0. Defaults to True. + + + Yields: + Iterator[OrderedDict]: A generator of state dict shard + """ + sharder = _StateDictSharder(max_shard_size) + + # get the mapping between copies and fp16 parameters + fp16_to_fp32 = dict() + for p, fp32_p in zip(self.fp16_params, self.fp32_params): + fp16_to_fp32[p] = fp32_p + + # key is fp32 param, and value is gathered param on CPU + gathered_param_buffer = dict() + for name, param in self.name2param.items(): + if param is not None: + if is_ddp_ignored(param): + # deal with ddp ignored parameters + gathered_param = param if keep_vars else param.detach() + else: + # as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16 + fp32_param = fp16_to_fp32[param] + if fp32_param not in gathered_param_buffer: + chunk = self.chunk_manager.get_chunk(fp32_param) + gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) + gathered_param = gathered_param_buffer.pop(fp32_param) + + block = sharder.append(prefix + name, gathered_param) + if block is not None: + yield block + + del fp16_to_fp32 + del gathered_param_buffer + + # save all buffers + for name, buf in self.named_buffers(): + if buf is not None and name not in self._non_persistent_buffers_set: + buffer = buf if keep_vars else buf.detach() + block = sharder.append(prefix + name, buffer) + if block is not None: + yield block + # save extra states + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + extra_state = self.get_extra_state() + block = sharder.append(extra_state_key, extra_state) + if block is not None: + yield block + + yield sharder.current_block + + +class _StateDictSharder: + + def __init__(self, max_shard_size: int) -> None: + self.max_shard_size = max_shard_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]: + tensor_size = calculate_tensor_size(tensor) + ret_block = None + if self.current_block_size + tensor_size > self.max_shard_size: + ret_block = self.current_block + self.current_block = OrderedDict() + self.current_block_size = 0 + self.current_block[name] = tensor + self.current_block_size += tensor_size + return ret_block + + +class GeminiDDP(ZeroDDP): + + def __init__(self, + module: torch.nn.Module, + device: torch.device, + placement_policy: str = "cpu", + pin_memory: bool = False, + force_outputs_fp32: bool = False, + strict_ddp_mode: bool = False, + scatter_after_inference: bool = True, + search_range_mb: int = 32, + hidden_dim: Optional[int] = None, + min_chunk_size_mb: float = 32, + memstats: Optional[MemStats] = None, + verbose: bool = False) -> None: + """ + A torch.Module wrapper using ZeRO-DP and Gemini. + ZeRO is for parallel. Gemini is for memory management. + WARNING: The class will modify the module inline! + + Example: + model is initialized under the context of ColoInitContext + >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda") + >>> logits = model(x) + >>> loss = criterion(logits, labels) + >>> model.backward(loss) + + Args: + module (torch.nn.Module): the model to be wrapped. + device (torch.device): device to place the model. + placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu". + pin_memory (bool, optional): use pin memory on CPU. Defaults to False. + force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. + search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32. + hidden_dim (int, optional): the hidden dimension of DNN. + Users can provide this argument to speed up searching. + If users do not know this argument before training, it is ok. We will use a default value 1024. + min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte. + If the aggregate size of parameters is still smaller than the minimum chunk size, + all parameters will be compacted into one small chunk. + memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer. + """ + # some ugly hotfix for the compatibility with Lightning + if search_range_mb is None: + search_range_mb = 32 + + chunk_manager = init_chunk_manager(model=module, + init_device=device, + hidden_dim=hidden_dim, + search_range_mb=search_range_mb, + min_chunk_size_mb=min_chunk_size_mb, + strict_ddp_flag=strict_ddp_mode, + verbose=verbose) + gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) + super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode, + scatter_after_inference) diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc2924858e6371de61fafe0160ace91b4ff182f --- /dev/null +++ b/colossalai/zero/gemini/gemini_hook.py @@ -0,0 +1,68 @@ +from contextlib import contextmanager +from enum import Enum +from functools import partial +from typing import List + +import torch + +from colossalai.tensor.param_op_hook import ColoParamOpHook +from colossalai.utils import is_ddp_ignored +from colossalai.zero.gemini import TensorState +from colossalai.zero.gemini.gemini_mgr import GeminiManager + + +class TrainingPhase(Enum): + FORWARD = 0 + BACKWARD = 1 + + +class GeminiZeROHook(ColoParamOpHook): + + def __init__(self, gemini_manager: GeminiManager) -> None: + super().__init__() + self._gemini_manager = gemini_manager + self._chunk_manager = gemini_manager.chunk_manager + self._training_phase = TrainingPhase.FORWARD + + def pre_op(self, params): + params = [p for p in params if not is_ddp_ignored(p)] + chunks = self._chunk_manager.get_chunks(params) + for p in params: + self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) + self._gemini_manager.sample_overall_data() + self._gemini_manager.adjust_layout(chunks) + for chunk in chunks: + self._chunk_manager.access_chunk(chunk) + + # record cuda model data of the current OP + self._gemini_manager.record_model_data_volume() + + def post_op(self, params): + params = [p for p in params if not is_ddp_ignored(p)] + for p in params: + tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD + self._chunk_manager.trans_tensor_state(p, tensor_state) + + def pre_forward(self, params: List[torch.Tensor]) -> None: + self.pre_op(params) + + def post_forward(self, params: List[torch.Tensor]) -> None: + self.post_op(params) + + def pre_backward(self, params: List[torch.Tensor]) -> None: + self.pre_op(params) + + def post_backward(self, params: List[torch.Tensor]) -> None: + self.post_op(params) + + @contextmanager + def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD): + old_training_phase = self._training_phase + try: + self._training_phase = training_phase + yield + finally: + self._training_phase = old_training_phase + + switch_to_backward = switch_training_phase + switch_to_forward = partial(switch_to_backward, training_phase=TrainingPhase.FORWARD) diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py new file mode 100644 index 0000000000000000000000000000000000000000..c38e6eff840dd8f9cb2681b2524609432bc47a57 --- /dev/null +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -0,0 +1,164 @@ +import functools +from time import time +from typing import List, Optional, Tuple + +import torch + +from .chunk import Chunk, ChunkManager +from .memory_tracer import ChunkMemStatsCollector, MemStats +from .placement_policy import PlacementPolicyFactory + + +class GeminiManager: + """ + Stateful Tensor Manager, inspired from PatrickStar + + PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management + https://arxiv.org/abs/2108.05818 + + Args: + placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'. + If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used. + If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used. + If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well. + Note that 'auto' policy can only work well when no other processes use CUDA during your training. + chunk_manager (ChunkManager): A ``ChunkManager`` instance. + memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration. + """ + + def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: + + assert placement_policy in PlacementPolicyFactory.get_policy_names() + self.policy_name = placement_policy + policy_cls = PlacementPolicyFactory.create(placement_policy) + self._chunk_manager = chunk_manager + + self._premade_memstats_ = memstats is not None + self._memstats = memstats + self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager, + self._memstats) if policy_cls.need_mem_stats else None + self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector) + self._compute_list: List[Tuple[Chunk, ...]] = [] + self._compute_idx: int = -1 + + self._h2d_volume = 0 + self._d2h_volume = 0 + self._layout_time = 0 + self._evict_time = 0 + self._warmup = True + self._comp_cuda_demand_time = 0 + + def reset_attributes(self): + self._compute_idx = -1 + self._h2d_volume = 0 + self._d2h_volume = 0 + self._layout_time = 0 + self._evict_time = 0 + self._comp_cuda_demand_time = 0 + + @property + def need_warmup(self) -> bool: + return self.policy_name in ('auto', 'const') + + def is_warmup(self): + return self._warmup + + def memstats(self): + """memstats + + get the memory statistics during training. + The stats could be collected by a runtime memory tracer, or collected by the GeminiManager. + Note, for the latter, you can not access the memstats before warmup iteration finishes. + """ + if self._premade_memstats_: + return self._memstats + else: + assert not self._warmup, "Gemini Manager has memstats after warm up! Now is during warmup." + return self._mem_stats_collector._memstats + + def pre_iter(self, *args): + if self._mem_stats_collector and self._warmup: + self._mem_stats_collector.start_collection() + + def post_iter(self): + """This function must be called when each iteration finishes + """ + if self._mem_stats_collector and self._warmup: + self._mem_stats_collector.finish_collection() + self._warmup = False + self.reset_attributes() + + def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None: + """ Adjust the layout of stateful tensors according to the information provided + by mem_stats_collector, which should belongs to a Sharded Model. + """ + # find stateful tensor in state COMPUTE + start = time() + self._record_chunks_order(chunks) + cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks) + self._layout_time += time() - start + + vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list, + cuda_demand=cuda_demand, + warmup=self._warmup, + compute_list=self._compute_list, + compute_idx=self._compute_idx) + + self._d2h_volume += vol + self._evict_time += evict_time + # move COMPUTE tensors to CUDA + self._h2d_volume += cuda_demand + + @functools.lru_cache(maxsize=None) + def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]): + start = time() + cuda_demand = 0 + for chunk in chunks: + if chunk.device_type == 'cuda': + if chunk.is_gathered: + pass + else: + cuda_demand += chunk.chunk_mem - chunk.shard_mem + elif chunk.device_type == 'cpu': + cuda_demand += chunk.chunk_mem + else: + raise RuntimeError + self._comp_cuda_demand_time += time() - start + + can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks() + return cuda_demand, can_evict_chunks + + def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None: + self._compute_idx += 1 + if self._warmup and self._placement_policy.need_mem_stats: + self._compute_list.append(chunks) + + @property + def default_device(self): + return self._placement_policy.get_default_device() + + def sample_overall_data(self): + if self._mem_stats_collector: + self._mem_stats_collector.sample_overall_data() + + def record_model_data_volume(self): + if self._mem_stats_collector: + self._mem_stats_collector.record_model_data_volume() + + @property + def chunk_manager(self): + return self._chunk_manager + + @property + def cuda_margin_mem(self) -> Optional[float]: + if self._mem_stats_collector: + return self._mem_stats_collector.cuda_margin_mem + return None + + @property + def is_cuda_margin_mem_avail(self) -> bool: + return self._placement_policy.need_mem_stats + + @staticmethod + def get_default_device(policy_name: str) -> torch.device: + return PlacementPolicyFactory.get_default_device(policy_name) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..71c4f65cb8d2942e9d0022628f84234c1842c11b --- /dev/null +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -0,0 +1,335 @@ +# this code is inspired by the DeepSpeed library and implemented with our own design from scratch +import math +import warnings +from enum import Enum +from typing import Any, Dict, Set, Tuple + +import torch +import torch.distributed as dist +from torch.nn import Parameter +from torch.optim import Optimizer + +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam +from colossalai.utils import disposable, get_current_device, is_ddp_ignored + +from .chunk import Chunk, ChunkManager +from .gemini_ddp import ZeroDDP + +__all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer'] + +_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} + + +class OptimState(Enum): + SCALED = 0 + UNSCALED = 1 + + +class ZeroOptimizer(ColossalaiOptimizer): + """A wrapper for optimizer. ``ZeroDDP`` and ``ZeroOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3). + + Note: + You must use ``ZeroDDP`` with ``ZeroOptimizer``. + + Note: + Make sure you set ``placement_policy`` of ``GeminiManager`` to `"auto"`, + if you set ``gpu_margin_mem_ratio > 0``. + + Args: + optim (Optimizer): An Optimizer instance. + module (ZeroDDP): A ``ZeroDDP`` instance. + gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) + which will be used when using hybrid CPU optimizer. + This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto". + Defaults to 0.0. + initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. + min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. + growth_factor (float, optional): Growth_factor used by DynamicGradScaler. Defaults to 2. + backoff_factor (float, optional): Backoff_factor used by DynamicGradScaler. Defaults to 0.5. + growth_interval (float, optional): Growth_interval used by DynamicGradScaler. Defaults to 1000. + hysteresis (float, optional): Hysteresis used by DynamicGradScaler. Defaults to 2. + max_scale (int, optional): Max_scale used by DynamicGradScaler. Defaults to 2**32. + clipping_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0. + norm_type (float, optional): The type of norm used for gradient clipping. Currently, only L2-norm (norm_type=2.0) + is supported in ZeroOptimizer. Defaults to 2.0. + verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False. + """ + + def __init__(self, + optim: Optimizer, + module: ZeroDDP, + gpu_margin_mem_ratio: float = 0.0, + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + clipping_norm: float = 0.0, + norm_type: float = 2.0, + verbose: bool = False, + **defaults: Any): + super().__init__(optim) + assert isinstance(module, ZeroDDP) + assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \ + f"{_AVAIL_OPTIM_LIST}" + self.module = module + self.gemini_manager = module.gemini_manager + self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager + self.optim_state = OptimState.UNSCALED + self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict() + self.param_to_chunk32: Dict[Parameter, Chunk] = dict() + self.chunk16_set: Set[Chunk] = set() + self.clipping_flag = clipping_norm > 0.0 + self.max_norm = clipping_norm + self.verbose = verbose + + if self.clipping_flag: + assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now" + + ddp_param_list = [] + for name, param in module.named_parameters(): + if is_ddp_ignored(param): + if param.requires_grad: + warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! " + "You should handle its optimizer update by yourself!") + else: + ddp_param_list.append(param) + + for p, fp32_p in zip(ddp_param_list, module.fp32_params): + chunk_16 = self.chunk_manager.get_chunk(p) + if chunk_16 not in self.chunk16_set: + chunk_16.l2_norm_flag = self.clipping_flag + self.chunk16_set.add(chunk_16) + + self.__init__optimizer() + + # Grad scaler + self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + self._logger = get_dist_logger() + + self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) + assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0' + # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid + # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors, + # and it must set `num_fp32_shards_per_param` correctly + self._should_move_fp32_params_h2d: bool = self.gemini_manager.is_cuda_margin_mem_avail and self.gpu_margin_mem_ratio > 0.0 and getattr( + optim, 'num_fp32_shards_per_param', 0) >= 2 + if self.gpu_margin_mem_ratio > 0.0 and not self.gemini_manager.is_cuda_margin_mem_avail: + self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0]) + + self._register_states = disposable(self._register_states_) + + def _set_grad_ptr(self): + for group in self.param_groups: + for fake_param in group['params']: + chunk32 = self.param_to_chunk32[fake_param] + begin, end = self.param_to_range[fake_param] + chunk16 = chunk32.paired_chunk + + fake_param.data = chunk16.payload[begin:end] + fake_param.grad = fake_param.data + fake_param.data = chunk32.payload[begin:end] + + def _update_fp16_params(self): + none_tensor = torch.empty([0]) + for group in self.param_groups: + for fake_param in group['params']: + assert fake_param.grad is None + fake_param.data = none_tensor.to(fake_param.device) + + for chunk16 in self.chunk16_set: + chunk16.optim_update() + + def _check_overflow(self): + # clear previous overflow record + self._found_overflow.fill_(self.module.overflow_counter) + + # all-reduce across global group + dist.all_reduce(self._found_overflow) + + return self._found_overflow.item() > 0 + + def _clear_global_norm(self) -> None: + for c16 in self.chunk16_set: + c16.l2_norm = None + + def _calc_global_norm(self) -> float: + norm_sqr: float = 0.0 + group_to_norm = dict() + for c16 in self.chunk16_set: + assert c16.l2_norm is not None + + if c16.is_gathered: + norm_sqr += c16.l2_norm + else: + # this chunk is sharded, use communication to collect total norm + if c16.torch_pg not in group_to_norm: + group_to_norm[c16.torch_pg] = 0.0 + group_to_norm[c16.torch_pg] += c16.l2_norm + + c16.l2_norm = None # clear l2 norm + + comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device()) + for group, part_norm in group_to_norm.items(): + comm_buffer.fill_(part_norm) + dist.all_reduce(comm_buffer, group=group) + norm_sqr += comm_buffer.item() + + global_norm = math.sqrt(norm_sqr) + return global_norm + + def _get_combined_scale(self): + loss_scale = 1 + + if self.optim_state == OptimState.SCALED: + loss_scale = self.loss_scale + self.optim_state = OptimState.UNSCALED + + combined_scale = loss_scale + if self.clipping_flag: + total_norm = self._calc_global_norm() + clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm + if clip > 1: + combined_scale = clip * loss_scale + + if combined_scale == 1: + return -1 + else: + return combined_scale + + @property + def loss_scale(self): + return self.grad_scaler.scale.item() + + def zero_grad(self, *args, **kwargs): + self.module.overflow_counter = 0 + return self.optim.zero_grad(set_to_none=True) + + def step(self, *args, **kwargs): + self._maybe_move_fp32_params() + self._set_grad_ptr() + + found_inf = self._check_overflow() + if found_inf: + self.optim_state = OptimState.UNSCALED # no need to unscale grad + self.grad_scaler.update(found_inf) # update gradient scaler + if self.verbose: + self._logger.info(f'Found overflow. Skip step') + self._clear_global_norm() # clear recorded norm + self.zero_grad() # reset all gradients + self._update_fp16_params() + return + + # get combined scale. combined scale = loss scale * clipping norm + # so that gradient = gradient / combined scale + combined_scale = self._get_combined_scale() + self.grad_scaler.update(found_inf) + + ret = self.optim.step(div_scale=combined_scale, *args, **kwargs) + self._register_states() + self.zero_grad() + self._update_fp16_params() + return ret + + def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0): + raise NotImplementedError + + def backward(self, loss: torch.Tensor): + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + self.module.backward(loss) + + def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor): + # This function is called except the last stage of pipeline parallel + # It receives the scaled grad from the previous rank + # No need to scale the grad again + # Need to unscale when optimizing + self.optim_state = OptimState.SCALED + self.module.backward_by_grad(tensor, grad) + + def _maybe_move_fp32_params(self): + if self._should_move_fp32_params_h2d: + self._should_move_fp32_params_h2d = False + available_cuda_margin_mem = self.gemini_manager.cuda_margin_mem * self.gpu_margin_mem_ratio + fp32_params_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param + fp32_params_used_cuda_margin_mem = 0 + + for group in self.param_groups: + for fake_param in group['params']: + chunk32 = self.param_to_chunk32[fake_param] + chunk16 = chunk32.paired_chunk + + if chunk32.device_type == 'cuda': + continue + + if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: + self.chunk_manager.move_chunk(chunk32, get_current_device()) + # stores grad now + self.chunk_manager.move_chunk(chunk16, get_current_device()) + self.module.set_chunk_grad_device(chunk16, get_current_device()) + fp32_params_used_cuda_margin_mem += chunk32.payload_mem + + for group in self.param_groups: + for fake_param in group['params']: + chunk32 = self.param_to_chunk32[fake_param] + if chunk32.device_type == 'cuda': + state = self.optim.state[fake_param] + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(get_current_device()) + + def _register_states_(self): + for group in self.optim.param_groups: + for p in group['params']: + state = self.optim.state[p] + for val in state.values(): + if isinstance(val, torch.Tensor): + self.chunk_manager.add_extern_static_tensor(val) + + def __init__optimizer(self): + + def get_range_pair(local_chunk: Chunk, local_param: Parameter): + param_info = local_chunk.tensors_info[local_param] + if local_chunk.keep_gathered: + return param_info.offset, param_info.end + begin = max(0, param_info.offset - local_chunk.shard_begin) + end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin) + return begin, end + + for group in self.optim.param_groups: + fake_params_list = list() + + for param in group['params']: + if is_ddp_ignored(param): + continue + chunk16 = self.chunk_manager.get_chunk(param) + range_pair = get_range_pair(chunk16, param) + if range_pair[0] >= range_pair[1]: + continue + + grad_device = self.module.grads_device[param] + fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device)) + self.param_to_chunk32[fake_param] = chunk16.paired_chunk + self.param_to_range[fake_param] = range_pair + + fake_params_list.append(fake_param) + + group['params'] = fake_params_list + + +class GeminiAdamOptimizer(ZeroOptimizer): + + def __init__(self, model: torch.nn.Module, **defaults: Any) -> None: + optimizer = HybridAdam(model.parameters(), **defaults) + super().__init__(optimizer, model, **defaults) diff --git a/colossalai/zero/gemini/memory_tracer/__init__.py b/colossalai/zero/gemini/memory_tracer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02c9d5754ec9a34c11531111a8fd6ca5e6698c96 --- /dev/null +++ b/colossalai/zero/gemini/memory_tracer/__init__.py @@ -0,0 +1,11 @@ +from .param_runtime_order import OrderedParamGenerator # isort:skip +from .memory_stats import MemStats # isort:skip +from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip +from .memstats_collector import MemStatsCollector # isort:skip +from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip +from .static_memstats_collector import StaticMemStatsCollector # isort:skip + +__all__ = [ + 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', + 'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator' +] diff --git a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py new file mode 100644 index 0000000000000000000000000000000000000000..f5eb05b4f22ac71d2bb61550d64992affd67e411 --- /dev/null +++ b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py @@ -0,0 +1,36 @@ +from typing import Optional + +from colossalai.utils import get_current_device +from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.zero.gemini.chunk import ChunkManager + +from .memory_stats import MemStats +from .memstats_collector import MemStatsCollector + + +class ChunkMemStatsCollector(MemStatsCollector): + + def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: + """ + + Memory Statistic Collector for Chunks. + + Args: + chunk_manager (ChunkManager): the chunk manager. + memstats (Optional[MemStats], optional): memory statistics collected by RMT. Defaults to None. + """ + super().__init__(memstats) + self._chunk_manager = chunk_manager + + # override + def record_model_data_volume(self) -> None: + """ + record model data volumn on cuda and cpu. + """ + if self._start_flag and not self.use_outside_memstats: + cuda_mem = self._chunk_manager.total_mem['cuda'] + self._memstats.record_max_cuda_model_data(cuda_mem) + + @property + def cuda_margin_mem(self) -> float: + return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda diff --git a/colossalai/zero/gemini/memory_tracer/memory_monitor.py b/colossalai/zero/gemini/memory_tracer/memory_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..f8d99dbce7a43a8089dd1ebddd9bce6979a17f40 --- /dev/null +++ b/colossalai/zero/gemini/memory_tracer/memory_monitor.py @@ -0,0 +1,147 @@ +import json +from abc import abstractmethod +from concurrent.futures import ThreadPoolExecutor +from time import sleep, time + +import torch + +from colossalai.utils import colo_device_memory_used, get_current_device + + +class MemoryMonitor: + """Base class for all types of memory monitor. + All monitors should have a list called `time_stamps` and a list called `mem_stats`. + """ + + def __init__(self): + self.time_stamps = [] + self.mem_stats = [] + + def __len__(self): + return len(self.mem_stats) + + @abstractmethod + def start(self): + pass + + @abstractmethod + def finish(self): + pass + + def state_dict(self): + return { + "time_stamps": self.time_stamps, + "mem_stats": self.mem_stats, + } + + def save(self, filename): + with open(filename, "w") as f: + json.dump(self.state_dict(), f) + + def clear(self): + self.mem_stats.clear() + self.time_stamps.clear() + + +class AsyncMemoryMonitor(MemoryMonitor): + """ + An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU + at interval of `1/(10**power)` sec. + + The idea comes from Runtime Memory Tracer of PatrickStar + `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_ + + Usage:: + + async_mem_monitor = AsyncMemoryMonitor() + input = torch.randn(2, 20).cuda() + OP1 = torch.nn.Linear(20, 30).cuda() + OP2 = torch.nn.Linear(30, 40).cuda() + + async_mem_monitor.start() + output = OP1(input) + async_mem_monitor.finish() + async_mem_monitor.start() + output = OP2(output) + async_mem_monitor.finish() + async_mem_monitor.save('log.pkl') + + Args: + power (int, optional): the power of time interva. Defaults to 10. + + .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management: + https://arxiv.org/abs/2108.05818 + """ + + def __init__(self, power: int = 10): + super().__init__() + self.keep_measuring = False + + current_device = get_current_device() + + def _set_cuda_device(): + torch.cuda.set_device(current_device) + + self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device) + self.monitor_thread = None + self.interval = 1 / (10**power) + + def set_interval(self, power: int): + self.clear() + self.interval = 1 / (10**power) + + def is_measuring(self): + return self.keep_measuring + + def start(self): + self.keep_measuring = True + self.monitor_thread = self.executor.submit(self._measure_usage) + + def finish(self): + if self.keep_measuring is False: + return 0 + + self.keep_measuring = False + max_usage = self.monitor_thread.result() + + self.monitor_thread = None + self.time_stamps.append(time()) + self.mem_stats.append(max_usage) + return max_usage + + def _measure_usage(self): + max_usage = 0 + while self.keep_measuring: + max_usage = max( + max_usage, + colo_device_memory_used(get_current_device()), + ) + sleep(self.interval) + return max_usage + + +class SyncCudaMemoryMonitor(MemoryMonitor): + """ + A synchronized cuda memory monitor. + It only record the maximum allocated cuda memory from start point to finish point. + """ + + def __init__(self, power: int = 10): + super().__init__() + + def start(self): + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + def finish(self) -> int: + """ + return max gpu memory used since latest `start()`. + + Returns: + int: max GPU memory + """ + torch.cuda.synchronize() + self.time_stamps.append(time()) + max_usage = torch.cuda.max_memory_allocated() + self.mem_stats.append(max_usage) + return max_usage diff --git a/colossalai/zero/gemini/memory_tracer/memory_stats.py b/colossalai/zero/gemini/memory_tracer/memory_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..9a45034ee27e2cc0aeea2b100cc69fc2e2df71e7 --- /dev/null +++ b/colossalai/zero/gemini/memory_tracer/memory_stats.py @@ -0,0 +1,127 @@ +from typing import Any, Dict, List, Optional + +import torch + +from .param_runtime_order import OrderedParamGenerator + + +class MemStats(object): + + def __init__(self) -> None: + """ + Store the non model data statistics used for Gemini and ZeroOptimizer. + """ + # (preop_step, List[param]) + self._step_param_dict = dict() + # (param, List[preop_step]) + self._param_step_dict = dict() + # (preop_step, non_model_data) non model data used during preop_step ~ (preop_step+1) + self._step_nmd_dict = dict() + self._param_runtime_order = OrderedParamGenerator() + + self._preop_step = 0 + + self._prev_overall_cuda = -1 + self._max_overall_cuda = 0 + self._prev_md_cuda = -1 + + # old version + self._model_data_cuda_list = [] + self._model_data_cpu_list = [] + + self._overall_cuda_list = [] + self._overall_cpu_list = [] + + self._non_model_data_cuda_list = [] + self._non_model_data_cpu_list = [] + + def calc_max_cuda_non_model_data(self): + if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1: + max_cuda_non_model_data = self._prev_overall_cuda - self._prev_md_cuda + self._step_nmd_dict[self._preop_step - 1] = max_cuda_non_model_data + # compatibility of the old version. + self._non_model_data_cuda_list.append(max_cuda_non_model_data) + + def record_max_cuda_model_data(self, val): + self._prev_md_cuda = val + + def record_max_cuda_overall_data(self, val): + self._prev_overall_cuda = val + self._max_overall_cuda = max(self._max_overall_cuda, val) + + @property + def max_overall_cuda(self): + return self._max_overall_cuda + + def increase_preop_step(self, param_list: List[torch.nn.Parameter]): + """ + the time step is increased. param list is used between current and the next + time step. + + Args: + param_list (List[torch.nn.Parameter]): a list of torch paramters. + """ + for p in param_list: + if p not in self._param_step_dict: + self._param_step_dict[p] = [self._preop_step] + else: + self._param_step_dict[p].append(self._preop_step) + self._param_runtime_order.append(p) + self._step_param_dict[self._preop_step] = param_list + self._preop_step += 1 + + def param_used_step(self, param: torch.nn.Parameter) -> Optional[List[int]]: + """param_used_step + get the timestep list using the param + + Args: + param (torch.nn.Parameter): a torch param + + Returns: + Optional[List[int]]: a list of int indicates the time step of preop hook. + """ + if param not in self._param_step_dict: + return None + else: + return self._param_step_dict[param] + + def param_order(self): + if self._param_runtime_order.is_empty(): + raise RuntimeError + else: + return self._param_runtime_order + + def non_model_data_list(self, device_type: str) -> List[int]: + if device_type == 'cuda': + return self._non_model_data_cuda_list + elif device_type == 'cpu': + return self._non_model_data_cpu_list + else: + raise TypeError + + def max_non_model_data(self, device_type: str) -> float: + if device_type == 'cuda': + return max(self._non_model_data_cuda_list) + elif device_type == 'cpu': + return max(self._non_model_data_cpu_list) + else: + raise TypeError + + def clear(self): + self._model_data_cuda_list = [] + self._overall_cuda_list = [] + + self._model_data_cpu_list = [] + self._overall_cpu_list = [] + + self._non_model_data_cpu_list = [] + self._non_model_data_cuda_list = [] + + self._param_runtime_order.clear() + self._step_param_dict.clear() + self._param_step_dict.clear() + self._step_nmd_dict.clear() + self._preop_step = 0 + + self._prev_overall_cuda = -1 + self._prev_md_cuda = -1 diff --git a/colossalai/zero/gemini/memory_tracer/memstats_collector.py b/colossalai/zero/gemini/memory_tracer/memstats_collector.py new file mode 100644 index 0000000000000000000000000000000000000000..0694be48550aac735b3456e30edf4a0ddcd24e26 --- /dev/null +++ b/colossalai/zero/gemini/memory_tracer/memstats_collector.py @@ -0,0 +1,97 @@ +import time +from typing import Optional + +from .memory_monitor import SyncCudaMemoryMonitor +from .memory_stats import MemStats + + +class MemStatsCollector: + """ + A Memory statistic collector. + It works in two phases. + Phase 1. Collection Phase: collect memory usage statistics of CPU and GPU. + The first iteration of DNN training. + Phase 2. Runtime Phase: use the read-only collected stats + The rest iterations of DNN training. + + It has a Sampling counter which is reset after DNN training iteration. + """ + + def __init__(self, memstats: Optional[MemStats] = None) -> None: + self._mem_monitor = SyncCudaMemoryMonitor() + self._sampling_time = [] + + self._start_flag = False + self._step_idx = 0 + self._step_total = 0 + if memstats is not None: + self.use_outside_memstats = True + self._memstats = memstats + else: + self.use_outside_memstats = False + self._memstats = MemStats() + + def next_period_non_model_data_usage(self, device_type: str) -> int: + """Maximum non model data memory usage during the next Op run + + Args: + device_type (str): device type, can be 'cpu' or 'cuda'. + + Returns: + int: max non model data memory usage of current sampling period + """ + assert not self._start_flag, 'Cannot get mem stats info during collection phase.' + assert self._step_total > 0, 'Cannot get mem stats info before collection phase.' + assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \ + f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\ + f"step total {self._step_total}" + next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx] + self._step_idx = (self._step_idx + 1) % self._step_total + return next_non_model_data + + @property + def sampling_time(self): + return [t - self._sampling_time[0] for t in self._sampling_time] + + def start_collection(self): + self._start_flag = True + self._mem_monitor.start() + + def finish_collection(self): + self.sample_overall_data() + # self._step_total = len(self._sampling_time) + self._step_total = len(self._memstats.non_model_data_list('cuda')) + self._start_flag = False + print(f'finish_collection {self._step_total}') + + # deprecated + def record_model_data_volume(self) -> None: + """ + Sampling model data statistics. + """ + if self._start_flag and not self.use_outside_memstats: + from colossalai.zero.legacy.gemini import StatefulTensor + + # The following code work for ZeroInitContext, which is deprecated in v0.1.12 + cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] + self._memstats.record_max_cuda_model_data(cuda_mem) + + def sample_overall_data(self) -> None: + """ + Sampling overall and non model data cuda memory statistics. + """ + if self._start_flag and not self.use_outside_memstats: + cuda_overall = self._mem_monitor.finish() + self._memstats.record_max_cuda_overall_data(cuda_overall) + self._memstats.calc_max_cuda_non_model_data() + + self._mem_monitor.start() + + if self._start_flag: + self._sampling_time.append(time.time()) + + def clear(self) -> None: + self._memstats.clear() + self._start_flag = False + self._step_idx = 0 + self._step_total = 0 diff --git a/colossalai/zero/gemini/memory_tracer/param_runtime_order.py b/colossalai/zero/gemini/memory_tracer/param_runtime_order.py new file mode 100644 index 0000000000000000000000000000000000000000..638c0533ce926b6629906d8b113161345017295d --- /dev/null +++ b/colossalai/zero/gemini/memory_tracer/param_runtime_order.py @@ -0,0 +1,42 @@ +from abc import ABC + +import torch + + +class ParamGenerator(ABC): + + def append(self, param: torch.nn.Parameter): + pass + + def generate(self): + pass + + def clear(self): + pass + + +class OrderedParamGenerator(ParamGenerator): + """OrderedParamGenerator + + Contain the order of parameters visited during runtime. + """ + + def __init__(self) -> None: + self.param_visited_order = [] + + def append(self, param: torch.nn.Parameter): + self.param_visited_order.append(param) + + def generate(self): + visited_set = set() + for p in self.param_visited_order: + if p not in visited_set: + yield p + visited_set.add(p) + del visited_set + + def is_empty(self): + return len(self.param_visited_order) == 0 + + def clear(self): + self.param_visited_order = [] diff --git a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..0c9eac8b63e3662c10e358c6653fdd524b7d5e6e --- /dev/null +++ b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py @@ -0,0 +1,104 @@ +import torch.nn + +from colossalai.nn.parallel.data_parallel import _cast_float +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.zero.legacy.gemini.ophooks.runtime_mem_tracer_hook import ( + GradMemStats, + GradMemTracerHook, + ParamMemTracerHook, +) + +from .memory_stats import MemStats + +__all__ = ['RuntimeMemTracer'] + + +class RuntimeMemTracer(): + """RuntimeMemTracer for the module training using ColoParameter. + + Trace non-model memory usage during fwd+bwd process. + It is obtained by using a tensor with the same shape as the training process as the inputs + and running an single fwd+bwd to trace the statistics. + + NOTE() + 1. The premise to use this tracer is that the target DNN execute the same operations at each iterations, + 2. Module buffers are viewed as non-model data. + """ + + def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half): + super().__init__() + self.module = module + self.dtype = dtype + self._gradstat = GradMemStats() + self._memstats = MemStats() + self.param_op_hook = ParamMemTracerHook(self._memstats, self._gradstat) + self.grad_hook = GradMemTracerHook(self._gradstat) + self.cpu_param_data_dict = {} + + for p in module.parameters(): + p.data = p.data.to(dtype) + + self._cast_buffers_to_cuda_dtype() + + def parameters_in_runtime_order(self): + return self._memstats._param_runtime_order.generate() + + def memstats(self): + return self._memstats + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def _backup_params(self): + """ + The function is called before forward. Backup model params on cpu. + """ + for p in self.module.parameters(): + self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu") + self.cpu_param_data_dict[p].copy_(p.data) + + def _restore_params(self): + """ + This function is called after backward. Restore model params. + """ + for p in self.module.parameters(): + p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad) + p.data.copy_(self.cpu_param_data_dict[p]) + self.cpu_param_data_dict.clear() + + def _pre_forward(self): + self._clear_cuda_mem_info() + self._backup_params() + self.grad_hook.register_grad_hook(self.module) + self.param_op_hook.mem_monitor.start() + + def forward(self, *args, **kwargs): + args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype) + self.module.zero_grad(set_to_none=True) + self._pre_forward() + with ColoParamOpHookManager.use_hooks(self.param_op_hook): + outputs = self.module(*args, **kwargs) + return outputs + + def backward(self, loss): + with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): + loss.backward() + self._post_backward() + + def _post_backward(self): + cuda_volume = self.param_op_hook.mem_monitor.finish() + self._memstats.record_max_cuda_overall_data(cuda_volume) + # calc the last Op non model data + self._memstats.calc_max_cuda_non_model_data() + self.grad_hook.remove_grad_hook() + self._restore_params() + + def _clear_cuda_mem_info(self): + self._memstats.clear() + self._gradstat.clear() + + def _cast_buffers_to_cuda_dtype(self): + for buffer in self.module.buffers(): + buffer.data = buffer.cuda() + if torch.is_floating_point(buffer): + buffer.data = buffer.data.to(self.dtype) diff --git a/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py new file mode 100644 index 0000000000000000000000000000000000000000..b8f9a095f4224a500cbf630f227b9c257d2db3b2 --- /dev/null +++ b/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py @@ -0,0 +1,105 @@ +from typing import Optional + +import torch +import torch.nn as nn +from torch.fx import symbolic_trace + +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta +from colossalai.zero.gemini.chunk import ChunkManager + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +from .chunk_memstats_collector import ChunkMemStatsCollector + + +class ModuleInfos: + + def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str, + parent_module: torch.nn.Module): + self.module = module + self.module_name = module_name + self.module_full_name = module_full_name + self.parent_module = parent_module + + +class StaticMemStatsCollector(ChunkMemStatsCollector): + """ + A Static Memory statistic collector. + """ + + def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None: + super().__init__(chunk_manager) + self.module = module + self.module_info_list = [] + + def init_mem_stats(self, *inputs): + + self.register_opnodes_recursively(self.module) + self.refactor_module() + + self.module = self.module.cpu() + self.module.train() + + data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs] + gm = symbolic_trace(self.module) + interp = MetaInfoProp(gm) + interp.propagate(*data) + + total_mem = 0 + for inp in inputs: + total_mem += inp.numel() * inp.element_size() + last_node = None + module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list] + for node in gm.graph.nodes: + total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node) + if node.op == "call_module": + if node.name.endswith("_0") and node.name[:-2] in module_name_list: + self._non_model_data_cuda_list.append(total_mem) + last_node = node + self._non_model_data_cuda_list.append(total_mem) + self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:] + + cur_module_mem_fwd = 0 + cur_module_mem_bwd = 0 + grad_module_out = last_node.meta["fwd_mem_out"] + for node in gm.graph.nodes.__reversed__(): + cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node) + cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] + if node.op == "call_module": + if node.name.endswith("_0") and node.name[:-2] in module_name_list: + self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd) + total_mem = total_mem - cur_module_mem_fwd + cur_module_mem_fwd = 0 + cur_module_mem_bwd = 0 + grad_module_out = node.meta["bwd_mem_out"] + + self._step_total = len(self._non_model_data_cuda_list) + self.recover_module() + + def refactor_module(self): + for modInfo in self.module_info_list: + temp_node = nn.Sequential(nn.ReLU(), modInfo.module) + modInfo.parent_module.__setattr__(modInfo.module_name, temp_node) + + def recover_module(self): + for modInfo in self.module_info_list: + modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module) + + def register_opnodes_recursively(self, + module: torch.nn.Module, + name: str = "", + full_name: str = "", + parent_module: Optional[torch.nn.Module] = None): + + assert isinstance(module, torch.nn.Module) + + for child_name, child in module.named_children(): + self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module) + + # Early return on modules with no parameters. + if len(list(module.parameters(recurse=False))) == 0: + return + + self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module)) diff --git a/colossalai/zero/gemini/memory_tracer/utils.py b/colossalai/zero/gemini/memory_tracer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6962c058110e245a6bbd2470d75c137b54202aae --- /dev/null +++ b/colossalai/zero/gemini/memory_tracer/utils.py @@ -0,0 +1,59 @@ +from typing import Optional, Tuple + +import torch + + +def colo_model_optimizer_usage(optim) -> Tuple[int, int]: + """Trace the optimizer memory usage + + Args: + optim (ShardedOptimV2): an instance of ShardedOptimver + + Returns: + Tuple[int, int]: cuda/cpu memory usage in Byte + """ + if optim is None: + return 0, 0 + assert hasattr(optim, 'get_memory_usage'), f"{type(optim)} has no attr get_memory_usage()" + return optim.get_memory_usage() + + +def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]: + """ + Trace the model memory usage. + Args: + model (torch.nn.Module): a torch model + + Returns: + Tuple[int, int]: cuda memory usage in Byte, cpu memory usage in Byte + """ + if model is None: + return 0, 0 + + def _get_tensor_mem_use(t: Optional[torch.Tensor]): + if t is None: + return 0, 0 + assert isinstance(t, torch.Tensor) + _cpu_mem_usage, _cuda_mem_usage = 0, 0 + if t.device.type == 'cpu': + _cpu_mem_usage += t.numel() * t.element_size() + elif t.device.type == 'cuda': + _cuda_mem_usage += t.numel() * t.element_size() + return _cuda_mem_usage, _cpu_mem_usage + + cuda_mem_usage = 0 + cpu_mem_usage = 0 + for param in model.parameters(): + if hasattr(param, 'colo_attr'): + t_cuda, t_cpu = param.colo_attr.get_memory_usage() + cuda_mem_usage += t_cuda + cpu_mem_usage += t_cpu + else: + t_cuda, t_cpu = _get_tensor_mem_use(param.data) + cuda_mem_usage += t_cuda + cpu_mem_usage += t_cpu + t_cuda, t_cpu = _get_tensor_mem_use(param.grad) + cuda_mem_usage += t_cuda + cpu_mem_usage += t_cpu + + return cuda_mem_usage, cpu_mem_usage diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..84a868872f887a571858a44901dbd82d022dd18d --- /dev/null +++ b/colossalai/zero/gemini/placement_policy.py @@ -0,0 +1,246 @@ +import functools +from abc import ABC, abstractmethod +from time import time +from typing import Dict, List, Optional, Tuple, Type + +import torch + +from colossalai.utils import get_current_device +from colossalai.utils.memory import colo_device_memory_capacity + +from .chunk import Chunk, ChunkManager +from .memory_tracer import ChunkMemStatsCollector + + +class PlacementPolicy(ABC): + need_mem_stats: bool = False + + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: + self.chunk_manager = chunk_manager + self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector + + @abstractmethod + def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: + raise NotImplementedError + + @staticmethod + def get_default_device() -> torch.device: + return torch.device('cpu') + + +class CPUPlacementPolicy(PlacementPolicy): + + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: + super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: + volume = 0 + start = time() + for chunk in can_evict_chunks: + self.chunk_manager.release_chunk(chunk) + self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + volume += chunk.chunk_mem + return volume, time() - start + + +class CUDAPlacementPolicy(PlacementPolicy): + + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: + assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available' + super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: + return 0, 0 + + @staticmethod + def get_default_device() -> torch.device: + return get_current_device() + + +class AutoPlacementPolicy(PlacementPolicy): + + need_mem_stats: bool = True + # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase + # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() + # and AutoPlacementPolicy.set_steady_cuda_cap_ratio() + _warmup_non_model_data_ratio: float = 0.8 + _steady_cuda_cap_ratio: float = 0.9 + + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: + super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, + can_evict_chunks: List[Chunk], + cuda_demand: int = 0, + warmup: bool = True, + compute_list: Optional[List[Tuple[Chunk, ...]]] = None, + compute_idx: int = 0, + **kwargs) -> Tuple[int, float]: + """ + Evict tensors from CUDA device. + + Args: + can_evict_chunks (List[StatefulTensor]): the list of tensors that can be evicted. + cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0. + warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True. + compute_list (List[StatefulTensor], optional): TODO. Defaults to []. + compute_idx (int, optional): the idx of computing device. Defaults to 0. + + Raises: + RuntimeError: + + Returns: + int: the volume of memory that is evicted + """ + start = time() + cuda_capacity = colo_device_memory_capacity(get_current_device()) + used_cuda_model_data = self.chunk_manager.total_mem['cuda'] + if warmup: + # We designate a part of CUDA memory for model data in warmup iterations. + max_cuda_non_model_data_per_period = cuda_capacity * AutoPlacementPolicy._warmup_non_model_data_ratio + else: + # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. + max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda') + cuda_capacity *= AutoPlacementPolicy._steady_cuda_cap_ratio + total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period + avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data + freed_cuda_model_data = 0 + + if avail_cuda_model_data < cuda_demand: + # Move cuda_demand - avail_cuda_model_data volume of tensors + # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data + to_free_cuda_model_data = cuda_demand - avail_cuda_model_data + to_free_chunks = can_evict_chunks + if not warmup: + to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list)) + # print(self._sort_can_evict_chunks.cache_info()) + for chunk in to_free_chunks: + if freed_cuda_model_data >= to_free_cuda_model_data: + break + + self.chunk_manager.release_chunk(chunk) + self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + freed_cuda_model_data += chunk.chunk_mem + if freed_cuda_model_data < to_free_cuda_model_data: + raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! " + f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}") + return freed_cuda_model_data, time() - start + + @staticmethod + @functools.lru_cache(maxsize=None) + def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list: + next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks} + for i in range(len(compute_list) - 1, compute_idx, -1): + for chunk in compute_list[i]: + if chunk in next_compute_idx: + next_compute_idx[chunk] = i + next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True) + return [t for (t, idx) in next_compute_idx] + + @staticmethod + def set_warmup_non_model_data_ratio(ratio: float) -> None: + ratio = float(ratio) + assert 0.0 < ratio < 1.0 + AutoPlacementPolicy._warmup_non_model_data_ratio = ratio + + @staticmethod + def set_steady_cuda_cap_ratio(ratio: float) -> None: + ratio = float(ratio) + assert 0.0 < ratio < 1.0 + AutoPlacementPolicy._steady_cuda_cap_ratio = ratio + + +class ConstPlacementPolicy(PlacementPolicy): + + need_mem_stats: bool = False + _accessed_memory_boundary = 512 * 1024**2 + + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: + super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, + can_evict_chunks: List[Chunk], + cuda_demand: int = 0, + warmup: bool = True, + compute_list: Optional[List[Tuple[Chunk, ...]]] = None, + compute_idx: int = 0, + **kwargs) -> Tuple[int, float]: + """ + See the docstrings in the class `AutoPlacementPolicy`. + """ + start = time() + used_accessed_memory = self.chunk_manager.accessed_mem + avail_accessed_memory = ConstPlacementPolicy._accessed_memory_boundary - used_accessed_memory + freed_accessed_memory = 0 + + if avail_accessed_memory < cuda_demand: + to_free_memory = cuda_demand - avail_accessed_memory + to_free_chunks = can_evict_chunks + + if not warmup: + # sort all chunks + to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list)) + + for chunk in to_free_chunks: + if freed_accessed_memory >= to_free_memory: + break + + self.chunk_manager.release_chunk(chunk) + self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + freed_accessed_memory += chunk.chunk_mem + + if freed_accessed_memory < to_free_memory: + raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! " + f"Need {to_free_memory}, freed {freed_accessed_memory}") + return freed_accessed_memory, time() - start + + @staticmethod + @functools.lru_cache(maxsize=None) + def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list: + next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks} + for i in range(len(compute_list) - 1, compute_idx, -1): + for chunk in compute_list[i]: + if chunk in next_compute_idx: + next_compute_idx[chunk] = i + next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True) + return [t for (t, idx) in next_compute_idx] + + @staticmethod + def set_const_memory_boundary(cuda_memory_mb: int) -> None: + boundary = int(cuda_memory_mb * 1024**2) + assert boundary > 0 + ConstPlacementPolicy._accessed_memory_boundary = boundary + + +class PlacementPolicyFactory: + policies: Dict[str, Type[PlacementPolicy]] = { + 'cpu': CPUPlacementPolicy, + 'cuda': CUDAPlacementPolicy, + 'auto': AutoPlacementPolicy, + 'const': ConstPlacementPolicy + } + + @staticmethod + def create(policy_name: str) -> Type[PlacementPolicy]: + if policy_name not in PlacementPolicyFactory.policies: + raise TypeError(f"Unknown tensor placement policy {policy_name}") + return PlacementPolicyFactory.policies[policy_name] + + @staticmethod + def get_policy_names(): + return tuple(PlacementPolicyFactory.policies.keys()) + + @staticmethod + def get_default_device(policy_name: str) -> torch.device: + policy_cls = PlacementPolicyFactory.create(policy_name) + return policy_cls.get_default_device() diff --git a/colossalai/zero/gemini/utils.py b/colossalai/zero/gemini/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e52b5b836b0bec568e8a74d369969c0dcec0763a --- /dev/null +++ b/colossalai/zero/gemini/utils.py @@ -0,0 +1,104 @@ +from collections import OrderedDict +from copy import copy +from typing import Optional, Set + +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.utils import get_current_device + +from .chunk import Chunk + + +def get_temp_total_chunk_on_cuda(chunk: Chunk): + if chunk.is_gathered: + return chunk.cuda_global_chunk + + if chunk.cuda_shard is not None: + shard_temp = chunk.cuda_shard + else: + shard_temp = chunk.cpu_shard.to(get_current_device()) + + total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device()) + gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0)) + dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg) + + return total_temp + + +def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''): + """Get a dfs module list of the given module. Its order is same as the order of creations of modules. + """ + if memo is None: + memo = set() + if module not in memo: + for name, submodule in module._modules.items(): + if submodule is None: + continue + submodule_prefix = prefix + ('.' if prefix else '') + name + for m in _get_dfs_module_list(submodule, memo, submodule_prefix): + yield m + + memo.add(module) + yield prefix, module + + +def _get_shallow_copy_model(model: nn.Module): + """Get a shallow copy of the given model. Each submodule is different from the original submodule. + But the new submodule and the old submodule share all attributes. + """ + old_to_new = dict() + for name, module in _get_dfs_module_list(model): + new_module = copy(module) + new_module._modules = OrderedDict() + for subname, submodule in module._modules.items(): + if submodule is None: + continue + setattr(new_module, subname, old_to_new[submodule]) + old_to_new[module] = new_module + return old_to_new[model] + + +def get_static_torch_model(zero_ddp_model, + device=torch.device("cpu"), + dtype=torch.float32, + only_rank_0=True) -> torch.nn.Module: + """Get a static torch.nn.Module model from the given ZeroDDP module. + You should notice that the original ZeroDDP model is not modified. + Thus, you can use the original model in further training. + But you should not use the returned torch model to train, this can cause unexpected errors. + + Args: + zero_ddp_model (ZeroDDP): a zero ddp model + device (torch.device): the device of the final torch model + dtype (torch.dtype): the dtype of the final torch model + only_rank_0 (bool): if True, only rank0 has the coverted torch model + + Returns: + torch.nn.Module: a static torch model used for saving checkpoints or numeric checks + """ + from colossalai.zero.gemini.gemini_ddp import ZeroDDP + assert isinstance(zero_ddp_model, ZeroDDP) + + state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0) + colo_model = zero_ddp_model.module + torch_model = _get_shallow_copy_model(colo_model) + + if not only_rank_0 or dist.get_rank() == 0: + for (name, colo_module), (_, torch_module) in \ + zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)): + # clean the parameter list of the new torch module + torch_module._parameters = OrderedDict() + for sufix_param_name, param in colo_module.named_parameters(recurse=False): + # get the full name of the parameter + full_param_name = name + ('.' if name else '') + sufix_param_name + assert full_param_name in state_dict, \ + f"Can not find parameter `{full_param_name}` in the GeminiDDP module" + state_param = state_dict[full_param_name] + torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype)) + + setattr(torch_module, sufix_param_name, torch_param) + dist.barrier() + + return torch_model diff --git a/colossalai/zero/legacy/__init__.py b/colossalai/zero/legacy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3783d38e61b27cb97e3709a0db882acfa5667821 --- /dev/null +++ b/colossalai/zero/legacy/__init__.py @@ -0,0 +1,45 @@ +from typing import Tuple + +import torch +import torch.nn as nn + +from colossalai.logging import get_dist_logger + +from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator +from .shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from .sharded_model import ShardedModelV2 +from .sharded_optim import ShardedOptimizerV2 + + +def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config, + optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: + """ + A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading + + :param model: Your model object + :type model: :class:`torch.nn.Module` + :param optimizer_config: Your optimizer object + :type optimizer_config: :class:`dict` + + :return: (model, optimizer) + :rtype: Tuple + """ + + logger = get_dist_logger('convert_to_zero_v2') + + logger.info(f'optimizer_config is {optimizer_config}', ranks=[0]) + if optimizer_config is None: + optimizer_config = dict() + logger.info(f'model_config is {model_config}', ranks=[0]) + if model_config is None: + model_config = dict() + + zero_model = ShardedModelV2(model, **model_config) + zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config) + return zero_model, zero_optimizer + + +__all__ = [ + 'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context', + 'no_shard_zero_decrator', 'TensorShardStrategy', 'BucketTensorShardStrategy' +] diff --git a/colossalai/zero/legacy/gemini/__init__.py b/colossalai/zero/legacy/gemini/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..754ae9bc004431a793a01e1ff074fd8a1a972236 --- /dev/null +++ b/colossalai/zero/legacy/gemini/__init__.py @@ -0,0 +1,9 @@ +from .ophooks import BaseOpHook, register_ophooks_recursively +from .stateful_tensor import StatefulTensor +from .stateful_tensor_mgr import StatefulTensorMgr +from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy + +__all__ = [ + 'StatefulTensorMgr', 'StatefulTensor', 'CPUTensorPlacementPolicy', 'CUDATensorPlacementPolicy', + 'AutoTensorPlacementPolicy', 'register_ophooks_recursively', 'BaseOpHook' +] diff --git a/colossalai/zero/legacy/gemini/gemini_context.py b/colossalai/zero/legacy/gemini/gemini_context.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7da6b80fbaddc43074d3599bdd0fd18548f94b --- /dev/null +++ b/colossalai/zero/legacy/gemini/gemini_context.py @@ -0,0 +1,48 @@ +from enum import EnumMeta + + +class GeminiMemoryManager(object): + + def __init__(self, states_cls: EnumMeta): + super().__init__() + self.states_cls = states_cls + self._cnter = 0 # the counter of instances + + self.total_mem = dict() + self.state_mem = dict() + self.state_mem['cpu'] = dict() + self.state_mem['cuda'] = dict() + + self.reset() + + @property + def total_number(self): + return self._cnter + + def reset(self): + self._cnter = 0 # the counter of instances + + self.total_mem['cpu'] = 0 # memory occupation of instances in cpu + self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda + + # memory conditions for all states + for state in self.states_cls: + self.state_mem['cpu'][state] = 0 + self.state_mem['cuda'][state] = 0 + + def register_new_instance(self): + self._cnter += 1 + + def delete_instance(self): + self._cnter -= 1 + + def print_info(self): + print(f"Total number: {self.total_number}", + f"Total CPU memory occupation: {self.total_mem['cpu']}", + f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", + sep='\n') + + for state in self.states_cls: + print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}", + f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", + sep='\n') diff --git a/colossalai/zero/legacy/gemini/ophooks/__init__.py b/colossalai/zero/legacy/gemini/ophooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b65726166644f05b1972b88a7c0358fb7eeb737f --- /dev/null +++ b/colossalai/zero/legacy/gemini/ophooks/__init__.py @@ -0,0 +1,3 @@ +from .utils import BaseOpHook, register_ophooks_recursively + +__all__ = ["BaseOpHook", "register_ophooks_recursively"] diff --git a/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py b/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8fec64924ea325f88cf282cdf98dba8cf731f1 --- /dev/null +++ b/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py @@ -0,0 +1,32 @@ +import torch + +from colossalai.registry import OPHOOKS + +from . import BaseOpHook + + +@OPHOOKS.register_module +class ShardGradMemTracerHook(BaseOpHook): + """ + A hook to process sharded param before and after FWD and BWD operator executing. + """ + + def __init__(self): + super().__init__() + + def pre_fwd_exec(self, module: torch.nn.Module, *args): + pass + + def post_fwd_exec(self, module: torch.nn.Module, *args): + pass + + def pre_bwd_exec(self, module: torch.nn.Module, input, output): + for param in module.parameters(): + assert hasattr(param, '_sharded_grad') + param._sharded_grad.setup() + + def post_bwd_exec(self, module: torch.nn.Module, input): + pass + + def post_iter(self): + pass diff --git a/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py b/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a62fb9788a9de973ae28e535b81eb542248c3a --- /dev/null +++ b/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py @@ -0,0 +1,48 @@ +import torch + +from colossalai.registry import OPHOOKS + +from . import BaseOpHook + + +@OPHOOKS.register_module +class ShardParamHook(BaseOpHook): + """ + A hook to process sharded param before and after FWD and BWD operator executing. + """ + + def __init__(self): + super().__init__() + + def niter(self): + return self._niter + + def pre_fwd_exec(self, module: torch.nn.Module, *args): + for param in module.parameters(): + assert hasattr(param, 'ca_attr') + param.ca_attr.gather() + param.data = param.ca_attr.payload() + + def post_fwd_exec(self, module: torch.nn.Module, *args): + for param in module.parameters(): + assert hasattr(param, 'ca_attr') + param.ca_attr.shard() + param.data = param.ca_attr.payload() + + def pre_bwd_exec(self, module: torch.nn.Module, input, output): + for param in module.parameters(): + assert hasattr(param, 'ca_attr') + param.ca_attr.gather() + param.data = param.ca_attr.payload() + + def post_bwd_exec(self, module: torch.nn.Module, input): + for param in module.parameters(): + assert hasattr(param, 'ca_attr') + param.ca_attr.shard() + param.data = param.ca_attr.payload() + + def pre_iter(self): + pass + + def post_iter(self): + pass diff --git a/colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..f40d6ced1ee09a4aa2edca2dd5302111647937ce --- /dev/null +++ b/colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py @@ -0,0 +1,145 @@ +from contextlib import contextmanager +from enum import Enum +from functools import partial +from typing import List + +import torch + +from colossalai.tensor.param_op_hook import ColoParamOpHook +from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor +from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage + + +class TrainingPhase(Enum): + FORWARD = 0 + BACKWARD = 1 + + +class GradMemStats(): + + def __init__(self) -> None: + self.unreleased_grad_flag = {} + self.unreleased_grad_volume = 0 + + def clear(self): + self.unreleased_grad_flag.clear() + self.unreleased_grad_volume = 0 + + +class GradMemTracerHook(): + + def __init__(self, grad_stats: GradMemStats): + self.grad_hook_list = [] + self._grad_stats = grad_stats + + def grad_handle(self, p, grad): + assert self._grad_stats.unreleased_grad_flag[p] + free_storage(grad) + self._grad_stats.unreleased_grad_volume -= grad.numel() * grad.element_size() + self._grad_stats.unreleased_grad_flag[p] = False + + def register_grad_hook(self, module: torch.nn.Module): + for p in module.parameters(): + if p.requires_grad: + self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p))) + self._grad_stats.unreleased_grad_flag[p] = False + + def remove_grad_hook(self): + for hook in self.grad_hook_list: + hook.remove() + + +class ParamMemTracerHook(ColoParamOpHook): + + def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None: + super().__init__() + self._training_phase = TrainingPhase.FORWARD + self._memstats = memstats + self._grad_stats = gradstats + self.mem_monitor = SyncCudaMemoryMonitor() + + def _free_cuda_params(self, params): + for p in params: + if p.data.device.type == "cpu": + raise NotImplementedError("Only free cuda memory") + free_storage(p.data) + + def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]): + """ + move params to cuda + + Args: + params (List[torch.nn.Parameter]): target params + + Raises: + NotImplementedError: raise error when param has cpu grad + """ + for p in params: + cur_dev = p.data.device.type + if cur_dev == "cpu": + if p.grad is not None and p.grad.device.type == "cpu": + raise NotImplementedError("Only run in forward propagation") + p.data = torch.empty(p.data.shape, + device="cuda", + dtype=p.data.dtype, + requires_grad=p.data.requires_grad) + elif cur_dev == "cuda": + alloc_storage(p.data) + + def record_model_data_volume(self, params): + """ + get cuda model data used by params + """ + data_volume = self._grad_stats.unreleased_grad_volume + for p in params: + cur_model_data_volume = p.data.numel() * p.data.element_size() + data_volume += cur_model_data_volume + if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad: + # add param.grad, actually param.grad is None in this time + data_volume += cur_model_data_volume + if not self._grad_stats.unreleased_grad_flag[p]: + self._grad_stats.unreleased_grad_volume += cur_model_data_volume + self._grad_stats.unreleased_grad_flag[p] = True + # record max non model data used for this Op + self._memstats.record_max_cuda_model_data(data_volume) + + def pre_op(self, params): + max_cuda_used_pre_op = self.mem_monitor.finish() + # record max cuda overall data for prev OP. + self._memstats.record_max_cuda_overall_data(max_cuda_used_pre_op) + # record max cuda non model data for prev OP. + self._memstats.calc_max_cuda_non_model_data() + + self._allocate_params_on_cuda(params) + # record max cuda model data for current OP + self.record_model_data_volume(params) + + self.mem_monitor.start() + self._memstats.increase_preop_step(params) + + def post_op(self, params): + self._free_cuda_params(params) + + def pre_forward(self, params: List[torch.Tensor]) -> None: + self.pre_op(params) + + def post_forward(self, params: List[torch.Tensor]) -> None: + self.post_op(params) + + def pre_backward(self, params: List[torch.Tensor]) -> None: + self.pre_op(params) + + def post_backward(self, params: List[torch.Tensor]) -> None: + self.post_op(params) + + @contextmanager + def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD): + old_training_phase = self._training_phase + try: + self._training_phase = training_phase + yield + finally: + self._training_phase = old_training_phase + + switch_to_backward = switch_training_phase + switch_to_forward = partial(switch_to_backward, training_phase=TrainingPhase.FORWARD) diff --git a/colossalai/zero/legacy/gemini/ophooks/utils.py b/colossalai/zero/legacy/gemini/ophooks/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..84e8298c1d5186c5c292a68126499a674f31a593 --- /dev/null +++ b/colossalai/zero/legacy/gemini/ophooks/utils.py @@ -0,0 +1,142 @@ +# this code is inspired by the DeepSpeed library and implemented with our own design from scratch +from abc import ABC, abstractmethod +from typing import Callable, List, Optional + +import torch + + +class BaseOpHook(ABC): + """This class allows users to add customized operations + before and after the execution of a PyTorch submodule""" + + def __init__(self): + pass + + @abstractmethod + def pre_fwd_exec(self, module: torch.nn.Module, *args): + pass + + @abstractmethod + def post_fwd_exec(self, module: torch.nn.Module, *args): + pass + + @abstractmethod + def pre_bwd_exec(self, module: torch.nn.Module, input, output): + pass + + @abstractmethod + def post_bwd_exec(self, module: torch.nn.Module, input): + pass + + @abstractmethod + def post_iter(self): + pass + + +# apply torch.autograd.Function that calls a backward_function to tensors in output +def _apply_to_tensors_only(module, functional, backward_function, outputs): + if type(outputs) is tuple: + touched_outputs = [] + for output in outputs: + touched_output = _apply_to_tensors_only(module, functional, backward_function, output) + touched_outputs.append(touched_output) + return tuple(touched_outputs) + elif type(outputs) is torch.Tensor: + return functional.apply(module, backward_function, outputs) + else: + return outputs + + +class PreBackwardFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, module, pre_backward_function, outputs): + ctx.module = module + ctx.pre_backward_function = pre_backward_function + module.applied_pre_backward = False + outputs = outputs.detach() + return outputs + + @staticmethod + def backward(ctx, *args): + ctx.pre_backward_function(ctx.module) + return (None, None) + args + + +class PostBackwardFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, module, pre_backward_function, output): + ctx.module = module + output = output.detach() + ctx.pre_backward_function = pre_backward_function + return output + + @staticmethod + def backward(ctx, *args): + """ + Args: + activation_grad of the next layer. + Returns: + grad of the input activation. + """ + ctx.pre_backward_function(ctx.module) + return (None, None) + args + + +def register_ophooks_recursively(module: torch.nn.Module, + ophook_list: List[BaseOpHook], + name: str = "", + filter_fn: Optional[Callable] = None): + r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD.""" + assert isinstance(module, torch.nn.Module) + assert isinstance(ophook_list, (list, tuple)) + assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0' + for hook in ophook_list: + assert (isinstance(hook, BaseOpHook)) + + # Add hooks for submodules + for child_name, child in module.named_children(): + register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn) + + # Early return on modules with no parameters. + if len(list(module.parameters(recurse=False))) == 0: + return + + # return from flitered module + if filter_fn is not None and filter_fn(module): + return + + def _pre_forward_module_hook(submodule, *args): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.pre_fwd_exec(submodule, *args) + + def _post_forward_module_hook(submodule, *args): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.post_fwd_exec(submodule, *args) + + def _pre_backward_module_hook(submodule, inputs, output): + + def _run_before_backward_function(submodule): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.pre_bwd_exec(submodule, inputs, output) + + return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output) + + def _post_backward_module_hook(submodule, inputs): + + def _run_after_backward_function(submodule): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.post_bwd_exec(submodule, inputs) + + return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs) + + module.register_forward_pre_hook(_pre_forward_module_hook) + module.register_forward_hook(_post_forward_module_hook) + + module.register_forward_hook(_pre_backward_module_hook) + module.register_forward_pre_hook(_post_backward_module_hook) diff --git a/colossalai/zero/legacy/gemini/paramhooks/__init__.py b/colossalai/zero/legacy/gemini/paramhooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e423993784afb0b25e0b9d068e4968763919d03 --- /dev/null +++ b/colossalai/zero/legacy/gemini/paramhooks/__init__.py @@ -0,0 +1,3 @@ +from ._param_hookmgr import BaseParamHookMgr + +__all__ = ["BaseParamHookMgr"] diff --git a/colossalai/zero/legacy/gemini/paramhooks/_param_hookmgr.py b/colossalai/zero/legacy/gemini/paramhooks/_param_hookmgr.py new file mode 100644 index 0000000000000000000000000000000000000000..84f32be358e3b844ae7db394022d5b9a077352a7 --- /dev/null +++ b/colossalai/zero/legacy/gemini/paramhooks/_param_hookmgr.py @@ -0,0 +1,39 @@ +import functools +from typing import Callable, List + +import torch + + +class BaseParamHookMgr(object): + + def __init__(self, param_list: List[torch.nn.Parameter]) -> None: + r""" + register backward hook on every parameters of module + """ + self._param_list = param_list + self._hook_list = [] + + def register_backward_hooks(self, hook_call: Callable) -> None: + r""" + The hook_call will be called every time a gradient with respect to the a param in self.param_list + is computed. + The hook should have the following signature: + ``` + hook(param, grad) -> Tensor or None + ``` + """ + if not torch.is_grad_enabled(): + return # don't register grad hooks if grad isn't enabled + for p in self._param_list: + if p.requires_grad and not hasattr(p, '_base_param_hook'): + handle = p.register_hook(functools.partial(hook_call, p)) + p._base_param_hook = handle + + def remove_hooks(self) -> None: + """ + Remove hooks from model parameters. + """ + + for p in self._param_list: + if p.requires_grad and hasattr(p, '_base_param_hook'): + p._base_param_hook.remove() diff --git a/colossalai/zero/legacy/gemini/stateful_tensor.py b/colossalai/zero/legacy/gemini/stateful_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..1619ae40798d17e1fbff351e5678017dd54cf049 --- /dev/null +++ b/colossalai/zero/legacy/gemini/stateful_tensor.py @@ -0,0 +1,209 @@ +from enum import Enum +from typing import Optional, Union + +import torch + +from .gemini_context import GeminiMemoryManager + + +def sizeof_tensor(tensor: torch.Tensor): + return tensor.numel() * tensor.element_size() + + +class TensorState(Enum): + FREE = 0 + HOLD = 1 + HOLD_AFTER_FWD = 2 + HOLD_AFTER_BWD = 3 + COMPUTE = 4 + + +class StatefulTensor(object): + """A Structure stores a Torch Tensor and labeled states. + Inspired from the paper: + PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management + + https://arxiv.org/abs/2108.05818 + """ + # Global Stateful Tensor Manager + GST_MGR = GeminiMemoryManager(TensorState) + + def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None: + self._state = state + self._payload = None + self._payload_size = 0 # byte size of current payload + + StatefulTensor.GST_MGR.register_new_instance() + + if self._state == TensorState.FREE: + # when the state is free, payload should be None + assert maybe_tensor is None, f"payload has to None if state is {self._state}" + else: + # otherwise, payload should not be None + assert maybe_tensor is not None, f"payload can't be None if state is {self._state}" + self._payload = maybe_tensor + self._payload_size = sizeof_tensor(maybe_tensor) + self.__trans_state_update(TensorState.FREE, state) + + def data_ptr(self): + if self._payload is None: + return 0 # if a tensor has no storage, 0 should be returned + return self._payload.data_ptr() + + def set_null(self) -> None: + # notice that free stateful tensor do not need to become null again + if self.state != TensorState.FREE: + self.__trans_state_update(self.state, TensorState.FREE) + self.__release() + + def is_null(self) -> bool: + if self.state == TensorState.FREE: + # check sanity here + assert self.payload is None + return True + return False + + def trans_state(self, state: TensorState) -> None: + if self.state == TensorState.FREE: + # free stateful tensor can't change state + assert state == TensorState.FREE, "Free stateful tensor can't change to other states" + return + + self.__trans_state_update(self.state, state) + + if state == TensorState.FREE: + self.__release() + else: + self._state = state + + def move_to(self, device: Union[torch.device, int]): + assert self.state is not TensorState.FREE, "Can't move free stateful tensor" + + if not isinstance(device, torch.device): + to_device = torch.device('cuda', device) + else: + to_device = device + + from_device_type = self.device.type + if from_device_type == to_device.type: + # from device == to device + return + + # update manager's information + self.__trans_device_update(from_device_type, to_device.type) + self.payload.data = self.payload.data.to(to_device) + + def payload_copy(self, tensor) -> None: + self._payload.view(-1).copy_(tensor.view(-1)) + + def payload_reset(self, tensor) -> None: + + assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead" + + if self.payload is not None: + # release old payload + self.__trans_state_update(self.state, TensorState.FREE) + else: + # otherwise, set the state to HOLD for new payload + self._state = TensorState.HOLD + del self._payload + + self._payload = tensor + self._payload_size = sizeof_tensor(tensor) + # record new payload + self.__trans_state_update(TensorState.FREE, self.state) + + def payload_relay(self, rhs): + # relay the payload of rhs to current stateful tensor + # can't support null relay right now + assert not rhs.is_null() + + # now this function only support stateful tensor that has zero-length payload + # because it doesn't require memory manager updating + # you can extend this function by yourself + assert self.payload_size == 0 + + self._payload = rhs.payload + self._payload_size = rhs.payload_size + self._state = TensorState.HOLD + self.__trans_state_update(rhs.state, TensorState.HOLD) + + rhs.__release() + + @property + def payload(self) -> Optional[torch.Tensor]: + return self._payload + + @property + def payload_size(self) -> int: + return self._payload_size + + @property + def state(self) -> TensorState: + return self._state + + @property + def device(self) -> torch.device: + return self._payload.device + + @property + def dtype(self) -> torch.dtype: + return self._payload.dtype + + @property + def shape(self): + return self._payload.shape + + def to(self, device: torch.device): + raise RuntimeError("Use move_to(...) instead of call .to() on StatefulTensor") + + def to_(self, device: torch.device): + raise RuntimeError("Use move_to(...) instead of call .to_() on StatefulTensor") + + def __release(self): + # release current payload + # shouldn't be visible to users + self._state = TensorState.FREE + self._payload = None + self._payload_size = 0 + + def __trans_state_update(self, from_state: TensorState, to_state: TensorState): + """Update global manager when changing the state of a tensor + """ + manager = StatefulTensor.GST_MGR + size = self.payload_size + device_type = self.device.type + + if from_state != TensorState.FREE: + manager.state_mem[device_type][from_state] -= size + else: + # when from_state is FREE, the tensor is new to manager + # we should add its memory + manager.total_mem[device_type] += size + + if to_state != TensorState.FREE: + manager.state_mem[device_type][to_state] += size + else: + # when to_state is FREE, the tensor will be deleted soon + # we should sub its memory + manager.total_mem[device_type] -= size + + def __trans_device_update(self, from_type: str, to_type: str): + """Update global manager when changing the device of a tensor + """ + manager = StatefulTensor.GST_MGR + size = self.payload_size + state = self.state + + # update aggregated information + manager.total_mem[from_type] -= size + manager.total_mem[to_type] += size + + # update the information of each state + manager.state_mem[from_type][state] -= size + manager.state_mem[to_type][state] += size + + def __del__(self): + self.set_null() + StatefulTensor.GST_MGR.delete_instance() + del self diff --git a/colossalai/zero/legacy/gemini/stateful_tensor_mgr.py b/colossalai/zero/legacy/gemini/stateful_tensor_mgr.py new file mode 100644 index 0000000000000000000000000000000000000000..4f9ea7c6d5202238def5fc5be6d8a38724932860 --- /dev/null +++ b/colossalai/zero/legacy/gemini/stateful_tensor_mgr.py @@ -0,0 +1,103 @@ +import functools +import types +from time import time +from typing import List + +import torch + +from colossalai.logging import get_dist_logger +from colossalai.utils.cuda import get_current_device + +from .stateful_tensor import StatefulTensor, TensorState +from .tensor_placement_policy import TensorPlacementPolicy +from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage + + +class StatefulTensorMgr(object): + """ + Stateful Tensor Manager, inspired from PatrickStar + + PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management + https://arxiv.org/abs/2108.05818 + """ + + def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None: + self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy + self._stateful_tensor_list: List[StatefulTensor] = [] + + self._compute_list: List[StatefulTensor] = [] + self._compute_idx: int = -1 + + self._cpu_gpu_move_volume = 0 + self._layout_time = 0 + self._evict_time = 0 + self._warmup = True + + def register_stateful_tensor_list(self, tensor_list: List[StatefulTensor]) -> None: + assert self._stateful_tensor_list == [], "Can't register stateful tensors for manager twice" + self._stateful_tensor_list = tensor_list + for t in self._stateful_tensor_list: + assert isinstance(t, StatefulTensor) + t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t) + + def start_iter(self): + pass + + def finish_iter(self): + """This function must be called when each iteration finishes + """ + self._warmup = False + self._compute_idx = -1 + self._cpu_gpu_move_volume = 0 + self._layout_time = 0 + self._evict_time = 0 + + def adjust_layout(self) -> None: + """ Adjust the layout of stateful tensor according to the information provided + by mem_stats_collector, which should belongs to a Sharded Model. + """ + # find stateful tensor in state COMPUTE + cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE] + start = time() + move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup) + self._layout_time += time() - start + vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list, + cuda_demand=cuda_demand, + warmup=self._warmup, + compute_list=self._compute_list, + compute_idx=self._compute_idx) + self._cpu_gpu_move_volume += vol + self._evict_time += evict_time + # move COMPUTE tensors to CUDA + self._cpu_gpu_move_volume += cuda_demand + for t in move_to_cuda_tensor_list: + colo_model_data_tensor_move_inline(t, get_current_device()) + + @property + def cpu_gpu_move_volume(self): + return self._cpu_gpu_move_volume + + def _trans_state(self, trans_state_func, stateful_tensor, state): + trans_state_func(state) + if state == TensorState.COMPUTE: + self._compute_idx += 1 + if self._warmup: + self._compute_list.append(stateful_tensor) + + @functools.lru_cache(maxsize=None) + def _get_layout_info(self, compute_idx: int, warmup: bool): + move_to_cuda_tensor_list = [] + hold_cuda_tensor_list = [] + for tensor in self._stateful_tensor_list: + if tensor.state == TensorState.FREE: + continue + + if tensor.device.type == 'cuda': + if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]: + hold_cuda_tensor_list.append(tensor) + elif tensor.device.type == 'cpu': + if tensor.state == TensorState.COMPUTE: + move_to_cuda_tensor_list.append(tensor) + else: + raise RuntimeError + return move_to_cuda_tensor_list, hold_cuda_tensor_list diff --git a/colossalai/zero/legacy/gemini/tensor_placement_policy.py b/colossalai/zero/legacy/gemini/tensor_placement_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..165ae51fee60e2292a909ef4ffe8646fdae67353 --- /dev/null +++ b/colossalai/zero/legacy/gemini/tensor_placement_policy.py @@ -0,0 +1,139 @@ +import functools +from abc import ABC, abstractmethod +from time import time +from typing import List, Optional, Type + +import torch + +from colossalai.utils import get_current_device +from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.zero.gemini.memory_tracer import MemStatsCollector + +from .stateful_tensor import StatefulTensor +from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage + + +class TensorPlacementPolicy(ABC): + + def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None: + self.device: Optional[torch.device] = device + self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector + + @abstractmethod + def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None: + raise NotImplementedError + + +class CPUTensorPlacementPolicy(TensorPlacementPolicy): + + def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: + super().__init__(torch.device('cpu'), mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: + volume = 0 + for t in hold_cuda_tensor_list: + colo_model_data_tensor_move_inline(t, self.device) + volume += t.payload.numel() * t.payload.element_size() + return volume, 0 + + +class CUDATensorPlacementPolicy(TensorPlacementPolicy): + + def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: + assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available' + super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: + return 0, 0 + + +class AutoTensorPlacementPolicy(TensorPlacementPolicy): + + def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: + super().__init__(None, mem_stats_collector=mem_stats_collector) + # model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase + # TODO(ver217): make these args configurable + self._warmup_non_model_data_ratio: float = 0.8 + self._steady_cuda_cap_ratio: float = 0.9 + + def evict_tensors(self, + hold_cuda_tensor_list: List[StatefulTensor], + cuda_demand: int = 0, + warmup: bool = True, + compute_list: List[StatefulTensor] = [], + compute_idx: int = 0, + **kwargs) -> int: + """ + Evict tensors from CUDA device. + + Args: + hold_cuda_tensor_list (List[StatefulTensor]): the list of tensor in state of HOLD-like + cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0. + warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True. + compute_list (List[StatefulTensor], optional): TODO. Defaults to []. + compute_idx (int, optional): the idx of computing device. Defaults to 0. + + Raises: + RuntimeError: + + Returns: + int: the volume of memory that is evicted + """ + start = time() + cuda_capacity = colo_device_memory_capacity(get_current_device()) + used_cuda_model_data = StatefulTensor.GST_MGR.total_mem['cuda'] + if warmup: + # We designate a part of CUDA memory for model data in warmup iterations. + max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio + else: + # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. + max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda') + cuda_capacity *= self._steady_cuda_cap_ratio + total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period + avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data + freed_cuda_model_data = 0 + end = time() + if avail_cuda_model_data < cuda_demand: + # Move cuda_demand - avail_cuda_model_data volume of tensors + # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data + to_free_cuda_model_data = cuda_demand - avail_cuda_model_data + to_free_tensor_list = hold_cuda_tensor_list + if not warmup: + to_free_tensor_list = self._sort_hold_cuda_tensors(tuple(hold_cuda_tensor_list), compute_idx, + tuple(compute_list)) + # print(self._sort_hold_cuda_tensors.cache_info()) + end = time() + for t in to_free_tensor_list: + if freed_cuda_model_data >= to_free_cuda_model_data: + break + freed_cuda_model_data += t.payload_size + colo_model_data_tensor_move_inline(t, torch.device('cpu')) + if freed_cuda_model_data < to_free_cuda_model_data: + raise RuntimeError( + f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" + ) + return freed_cuda_model_data, end - start + + @staticmethod + @functools.lru_cache(maxsize=None) + def _sort_hold_cuda_tensors(hold_cuda_tensors: tuple, compute_idx: int, compute_list: tuple) -> list: + next_compute_idx = {t: len(compute_list) for t in hold_cuda_tensors} + for i in range(len(compute_list) - 1, compute_idx, -1): + if compute_list[i] in next_compute_idx: + next_compute_idx[compute_list[i]] = i + next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True) + return [t for (t, idx) in next_compute_idx] + + +class TensorPlacementPolicyFactory: + + @staticmethod + def create(policy_name: str) -> Type[TensorPlacementPolicy]: + if policy_name == 'cpu': + return CPUTensorPlacementPolicy + elif policy_name == 'cuda': + return CUDATensorPlacementPolicy + elif policy_name == 'auto': + return AutoTensorPlacementPolicy + else: + raise TypeError(f"Unknown tensor placement policy {policy_name}") diff --git a/colossalai/zero/legacy/gemini/tensor_utils.py b/colossalai/zero/legacy/gemini/tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f23e0253fdd1123306a32801d1cbc32884bc73 --- /dev/null +++ b/colossalai/zero/legacy/gemini/tensor_utils.py @@ -0,0 +1,120 @@ +from typing import Tuple, Union + +import torch + +from .stateful_tensor import StatefulTensor + + +def is_storage_empty(tensor: torch.Tensor) -> bool: + return tensor.storage().size() == 0 + + +def free_storage(tensor: torch.Tensor) -> None: + if not is_storage_empty(tensor): + tensor.storage().resize_(0) + + +def alloc_storage(tensor: torch.Tensor) -> None: + if is_storage_empty(tensor): + tensor.storage().resize_(tensor.numel()) + + +def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]: + if isinstance(tensor, StatefulTensor): + t = tensor.payload + elif isinstance(tensor, torch.Tensor): + t = tensor + else: + return 0, 0 + + cuda_use, cpu_use = 0, 0 + + mem_use = t.storage().size() * t.element_size() + if t.device.type == 'cuda': + cuda_use += mem_use + elif t.device.type == 'cpu': + cpu_use += mem_use + + return cuda_use, cpu_use + + +def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor, + torch.Tensor]) -> None: + """ + A colossal API for model data tensor move. + The src and target tensors could be resident on both CPU and GPU. + + NOTE() The source tensor payload will be removed after this function. + + The function will record the communication volume between CPU and GPU. + Args: + src_t (Union[StatefulTensor, torch.Tensor]): source tensor + tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor + """ + if isinstance(src_t, StatefulTensor): + src_t_payload = src_t.payload + else: + src_t_payload = src_t.data + src_dev = src_t_payload.device + + if isinstance(tgt_t, StatefulTensor): + tgt_t_payload = tgt_t.payload + else: + tgt_t_payload = tgt_t.data + + tgt_t_payload.copy_(src_t_payload) + + # remove payload of src_t + if isinstance(src_t, StatefulTensor): + src_t.set_null() + else: + src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype) + + +def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device, + int]) -> None: + """ + move a tensor to the target_device + Args: + t (Union[StatefulTensor, torch.Tensor]): the tensor be moved + target_device: a traget device, if type is int, it the index of cuda card. + """ + if not isinstance(target_device, torch.device): + target_device = torch.device(f'cuda:{target_device}') + + if isinstance(t, torch.Tensor): + t.data = t.data.to(target_device) + elif isinstance(t, StatefulTensor): + t.move_to(target_device) + else: + raise TypeError(f'colo_model_data_tensor_move_inline dose not accept type {type(t)}') + + +def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None: + """colo_model_data_move_to_cpu + move a model data tensor from gpu to cpu + Args: + t (Union[StatefulTensor, torch.Tensor]): _description_ + """ + # TODO() optimize the tensor moving with non-blocking + if isinstance(t, torch.Tensor): + t.data = t.data.cpu() + elif isinstance(t, StatefulTensor): + t.move_to(torch.device('cpu')) + else: + raise TypeError(f'colo_model_data_move_to_cpu dose not accept type {type(t)}') + + +def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor: + """ + Clone a model data tensor + Args: + t (Union[StatefulTensor, torch.Tensor]): a model data tensor + target_device (torch.device): the target device + Returns: + torch.Tensor: a cloned torch tensor + """ + # TODO() rename this function + colo_model_data_tensor_move_inline(t, target_device) + t_payload = t.payload if isinstance(t, StatefulTensor) else t + return t_payload diff --git a/colossalai/zero/legacy/init_ctx/__init__.py b/colossalai/zero/legacy/init_ctx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a6f81566a9de2d83561fe7d91f9052244b286b8 --- /dev/null +++ b/colossalai/zero/legacy/init_ctx/__init__.py @@ -0,0 +1,3 @@ +from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator + +__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator'] diff --git a/colossalai/zero/legacy/init_ctx/init_context.py b/colossalai/zero/legacy/init_ctx/init_context.py new file mode 100644 index 0000000000000000000000000000000000000000..a921ca0aa83a5d588c5339b799f5f3ee05386feb --- /dev/null +++ b/colossalai/zero/legacy/init_ctx/init_context.py @@ -0,0 +1,265 @@ +import contextlib +import functools +from contextlib import AbstractContextManager +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses +from colossalai.zero.legacy.shard_utils import BaseShardStrategy +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 +from colossalai.zero.legacy.sharded_param import ShardedParamV2 + + +@dataclass +class ZeroContextConfig: + """The configuration used to control zero context initialization. + + Args: + target_device (torch.device): The device where param data are after exiting the context. + is_replicated (bool, optional): Whether the param is replicated across data parallel group. + Some parameters are not replicated, e.g. parameters in MOE experts. + shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False. + """ + + target_device: torch.device + is_replicated: bool = True + shard_param: bool = False + + def __post_init__(self): + if self.shard_param: + assert self.is_replicated, "Non-replicated parameters can't be sharded." + + if self.is_replicated and not self.shard_param: + assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda." + + +class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): + """A context to initialize model. + + 1. Convert the model to fp16. + 2. The paramaters of the module are adapted to type ShardedParameter. + 3. Shard the param and grad according to flags. + + Args: + target_device (torch.device): The device where param data are after exiting the context. + shard_strategy (BaseShardStrategy): Shard strategy instance. + seed (int, optional): Random seed for weight initialization + shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False. + default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16. + model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int). + """ + + def __init__(self, + target_device: torch.device, + shard_strategy: BaseShardStrategy, + seed: int = 2**10 - 1, + shard_param: bool = False, + default_dtype: Optional[torch.dtype] = None, + model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)): + + super().__init__(default_dtype=default_dtype) + self.shard_strategy = shard_strategy + self.param_list = [] + self.model_numel_tensor = model_numel_tensor + self.seed = seed + self.dp_process_group = gpc.get_group(ParallelMode.DATA) + + self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param) + + ZeroContextMgr().current_context = self + + self.param_numel = {} + self.top_module = None + + @property + def target_device(self): + return self.config.target_device + + @property + def is_replicated(self): + return self.config.is_replicated + + @property + def shard_param(self): + return self.config.shard_param + + @staticmethod + def calc_fanin_fanout(tensor: torch.Tensor): + """We use this function to substitute fan-in and fan-out calculation in torch.nn.init. + This can help us get correct fan-in and fan-out for sharded tensor. + """ + assert isinstance(tensor, nn.Parameter), "Sharded tensor initialization is only allowed for parameters" + + # get correct shape of input tensor + if not hasattr(tensor, 'colo_attr') or not tensor.colo_attr.param_is_sharded: + tensor_shape = tensor.shape + else: + tensor_shape = tensor.colo_attr.sharded_data_tensor.origin_shape + + dimensions = len(tensor_shape) + if dimensions < 2: + raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") + + num_input_fmaps = tensor_shape[1] + num_output_fmaps = tensor_shape[0] + receptive_field_size = 1 + if dimensions > 2: + # math.prod is not always available, accumulate the product manually + # we could use functools.reduce but that is not supported by TorchScript + for s in tensor_shape[2:]: + receptive_field_size *= s + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + def _pre_context_exec(self): + """ + The Callback function when entering the context + """ + self.logger = get_dist_logger("ZeroInitContext") + + # substitute fan-in and fan-out calculation + self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out + nn.init._calculate_fan_in_and_fan_out = self.calc_fanin_fanout + + self.module_load_from_state_dict = nn.Module._load_from_state_dict + shard_strategy = self.shard_strategy if self.config.shard_param else None + nn.Module._load_from_state_dict = functools.partialmethod(ShardedModelV2._colo_load_from_state_dict, + shard_strategy=shard_strategy) + self.module_state_dict = nn.Module.state_dict + nn.Module.state_dict = functools.partialmethod(ShardedModelV2._colo_state_dict, + shard_strategy=shard_strategy, + state_dict_func=self.module_state_dict, + process_group=self.dp_process_group) + + # reserve rng states + self.cpu_rng_state = torch.get_rng_state() + self.cuda_rng_state = torch.cuda.get_rng_state() + + # set new seed for initialization, since we initialize sharded tensor separately + # we don't want all processes have the same seed + # otherwise all sharded tensors are same after init + offset = self.seed + 1 # we want to have more 1 in binary format seed + torch.manual_seed(self.seed + offset * dist.get_rank()) + + def _post_context_exec(self): + """The callback function when exiting context. + """ + # broadcast replicated no-shard parameters + src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] + for param in self.param_list: + assert hasattr(param, 'colo_attr') + if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated: + dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group) + param.colo_attr.set_data_none() + + del self.param_list + + nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout + nn.Module.load_state_dict = self.module_load_from_state_dict + nn.Module.state_dict = self.module_state_dict + torch.set_rng_state(self.cpu_rng_state) + torch.cuda.set_rng_state(self.cuda_rng_state) + + params = frozenset(self.top_module.parameters()) + for param in self.param_numel.keys(): + if param not in params: + self.param_numel[param] = 0 + self.model_numel_tensor.fill_(sum(self.param_numel.values())) + + def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): + """ + The function to call at the end of the constructor of each module. + NOTE() The module may be passed to this function multiple times. + """ + self.top_module = module + + def half_fn(t: torch.Tensor): + return t.half() if t.is_floating_point() else t + + for param in module.parameters(recurse=False): + # avoid adapting a param to ShardedParam twice + if hasattr(param, 'colo_attr'): + continue + + self.param_numel[param] = param.numel() + + # convert parameters to half + param_half = half_fn(param) + param.data = param_half + if param.grad is not None: + grad_half = half_fn(param.grad) + param.grad.data = grad_half + + # move torch parameters to the target device + target_device = self.target_device + param.data = param.data.to(target_device) + if param.grad is not None: + param.grad = param.grad.to(target_device) + + param.colo_attr = ShardedParamV2(param, set_data_none=True) + + if self.shard_param: + self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group) + + param.data = param.colo_attr.data_payload # set param.data to payload + + # mark whether the param is replicated + param.colo_attr.is_replicated = self.is_replicated + + # mark whether the param should keep not sharded + # if True, the param is used as Zero stage 2 + param.colo_attr.keep_not_shard = not self.shard_param + + self.param_list.append(param) + + # We must cast buffers + # If we use BN, buffers may be on CPU and Float + # We must cast them + for buffer in module.buffers(recurse=False): + buffer.data = buffer.data.to(device=torch.cuda.current_device()) + buffer.data = cast_tensor_to_fp16(buffer.data) + + +class ZeroContextMgr(metaclass=SingletonMeta): + current_context: Optional[ZeroInitContext] = None + + @contextlib.contextmanager + def hijack_context_config(self, **kwargs): + if self.current_context is None: + yield + else: + old_config = self.current_context.config + self.current_context.config = ZeroContextConfig(**kwargs) + yield + self.current_context.config = old_config + + +def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager: + return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()), + is_replicated=is_replicated, + shard_param=False) + + +def no_shard_zero_decrator(is_replicated: bool = True): + + def _wrapper(init_func): + + def _no_shard(*args, **kwargs): + with no_shard_zero_context(is_replicated): + ret = init_func(*args, **kwargs) + return ret + + return _no_shard + + return _wrapper diff --git a/colossalai/zero/legacy/shard_utils/__init__.py b/colossalai/zero/legacy/shard_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5d63a7e768a470b609ccd185012864752cb432 --- /dev/null +++ b/colossalai/zero/legacy/shard_utils/__init__.py @@ -0,0 +1,5 @@ +from .base_shard_strategy import BaseShardStrategy +from .bucket_tensor_shard_strategy import BucketTensorShardStrategy +from .tensor_shard_strategy import TensorShardStrategy + +__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy'] diff --git a/colossalai/zero/legacy/shard_utils/base_shard_strategy.py b/colossalai/zero/legacy/shard_utils/base_shard_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..7ca95109164028f1a1389ae0fa7547883af1e441 --- /dev/null +++ b/colossalai/zero/legacy/shard_utils/base_shard_strategy.py @@ -0,0 +1,22 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +import torch.distributed as dist + +from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor + + +class BaseShardStrategy(ABC): + + def __init__(self) -> None: + """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs. + """ + super().__init__() + + @abstractmethod + def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): + pass + + @abstractmethod + def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): + pass diff --git a/colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..d663104831ce79062fb6083e19401705efa64fd0 --- /dev/null +++ b/colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py @@ -0,0 +1,47 @@ +from typing import List, Optional + +import torch +import torch.distributed as dist +from torch._utils import _flatten_dense_tensors as flatten + +from colossalai.utils import get_current_device +from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor + +from .tensor_shard_strategy import TensorShardStrategy + + +class BucketTensorShardStrategy(TensorShardStrategy): + """Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together, + which will fully utilize network bandwidth. + It is especially useful when sub-module contains bias, + since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usually small). + """ + + def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): + + tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded] + if len(tensor_list) == 0: + return + target_device = tensor_list[0].device + dtype = tensor_list[0].dtype + buffer_list: List[torch.Tensor] = [] + tensor_numels = [t.payload.numel() for t in tensor_list] + buffer_size = sum(tensor_numels) + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + for i in range(world_size): + if i == rank: + buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device())) + else: + buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device())) + dist.all_gather(buffer_list, buffer_list[rank], group=process_group) + # Move to target device before splitting buffer + # Ensure we utilize maximum PCIE bandwidth + buffer_list = [buffer.to(target_device) for buffer in buffer_list] + offset = 0 + for i, t in enumerate(tensor_list): + gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list] + gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape) + t.payload_reset(gathered_payload) + t.is_sharded = False + offset += tensor_numels[i] diff --git a/colossalai/zero/legacy/shard_utils/commons.py b/colossalai/zero/legacy/shard_utils/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..bf5ae325caf4978cf3c3dcd6c9384444d8a4a8aa --- /dev/null +++ b/colossalai/zero/legacy/shard_utils/commons.py @@ -0,0 +1,22 @@ +from typing import Tuple + +import torch + + +def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]: + """Return the local shard of a full tensor.""" + # Shard using torch.chunk to match all-gather/reduce-scatter. + chunks = list(torch.flatten(tensor).chunk(world_size)) + while len(chunks) < world_size: + chunks.append(chunks[0].new_empty(0)) + + # Determine number of padding elements. + num_to_pad = chunks[0].numel() - chunks[rank].numel() + assert num_to_pad >= 0, num_to_pad + + shard = torch.zeros_like(chunks[0]) + length = chunks[rank].size(0) + shard_temp = shard[:length] + shard_temp.copy_(chunks[rank]) + + return shard, num_to_pad diff --git a/colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py b/colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..d1df4803b820988d20e06ac37e653cc3bb7759c9 --- /dev/null +++ b/colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py @@ -0,0 +1,59 @@ +from typing import List, Optional + +import torch +import torch.distributed as dist + +from colossalai.utils import get_current_device +from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline +from colossalai.zero.legacy.shard_utils import BaseShardStrategy +from colossalai.zero.legacy.shard_utils.commons import get_shard +from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor + + +class TensorShardStrategy(BaseShardStrategy): + """ + A naive implementation which shard each tensor evenly over all ranks + """ + + def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): + for t in tensor_list: + self._shard_tensor(t, process_group) + + def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): + for t in tensor_list: + self._gather_tensor(t, process_group) + + def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None): + """ Shard tensor among processes. + + Args: + t (ShardedTensor): a tensor to be sharded. + process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards. + Defaults to None. + """ + if t.is_sharded: + return + if t.payload.device.type == 'cuda': + assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\ + f" but current cuda device is {get_current_device()}" + sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group)) + t.payload_reset(sharded_payload) + t.is_sharded = True + + def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None): + if not t.is_sharded: + return + target_device = t.device + payload_numel = t.payload.numel() + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + + buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device()) + buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0)) + buffer_list[rank].copy_(t.payload) + + dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False) + gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape) + t.payload_reset(gathered_payload) + colo_model_data_tensor_move_inline(t, target_device) + t.is_sharded = False diff --git a/colossalai/zero/legacy/sharded_model/__init__.py b/colossalai/zero/legacy/sharded_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93120bdc34b4f18f48f7975e79d956c90d9ec50c --- /dev/null +++ b/colossalai/zero/legacy/sharded_model/__init__.py @@ -0,0 +1,3 @@ +from .sharded_model_v2 import ShardedModelV2 + +__all__ = ['ShardedModelV2'] diff --git a/colossalai/zero/legacy/sharded_model/_utils.py b/colossalai/zero/legacy/sharded_model/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2bd01531a78f517c5a431b4cb4d7e7af65024613 --- /dev/null +++ b/colossalai/zero/legacy/sharded_model/_utils.py @@ -0,0 +1,77 @@ +from typing import Any, Callable, List, Tuple, Union + +import torch +import torch.nn.functional as F + +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor + + +def get_gradient_predivide_factor(world_size: int) -> float: + factor: int = 1 + while world_size % factor == 0 and world_size / factor > factor: + factor *= 2 + return float(factor) + + +def free_storage(data: torch.Tensor) -> None: + """Free underlying storage of a Tensor.""" + if data.storage().size() > 0: + # Since we're modifying the Tensor's Storage directly, make sure the Tensor + # is the sole occupant of the Storage. + assert data.storage_offset() == 0 + data.storage().resize_(0) + + +@torch.no_grad() +def alloc_storage(data: torch.Tensor, size: torch.Size) -> None: + """Allocate storage for a tensor.""" + if data.storage().size() == size.numel(): # no need to reallocate + return + assert data.storage().size() == 0 + data.storage().resize_(size.numel()) + + +def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor: + if isinstance(tensor, StatefulTensor): + tensor = tensor.payload + if torch.is_floating_point(tensor) and tensor.dtype is torch.float32: + return tensor.half() + return tensor + + +def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Tensor: + if isinstance(tensor, StatefulTensor): + tensor = tensor.payload + + if torch.is_floating_point(tensor) and tensor.dtype is torch.float16: + return tensor.float() + return tensor + + +def apply_to_tensors(x: Any, fn: Callable): + if torch.is_tensor(x): + return fn(x) + elif isinstance(x, list): + return [apply_to_tensors(t, fn) for t in x] + elif isinstance(x, tuple): + return tuple(apply_to_tensors(t, fn) for t in x) + elif isinstance(x, dict): + return {key: apply_to_tensors(val, fn) for key, val in x.items()} + else: + return x + + +def cast_float_arguments(fn: Callable, *args: Any, **kwargs: Any) -> Tuple[Any, Any]: + return apply_to_tensors(args, fn), apply_to_tensors(kwargs, fn) + + +def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]: + """Chunk a given Tensor into num_chunks parts and add any necessary padding.""" + chunks = list(torch.flatten(tensor).chunk(num_chunks)) + # torch.chunk may return fewer than num_chunks chunks, pad accordingly. + num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel() + if num_pad_for_partial_chunk > 0: + chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk]) + if len(chunks) < num_chunks: + chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))]) + return chunks diff --git a/colossalai/zero/legacy/sharded_model/reduce_scatter.py b/colossalai/zero/legacy/sharded_model/reduce_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..4fb507382df9eae2d3efa35fdcdcb2704a9256dc --- /dev/null +++ b/colossalai/zero/legacy/sharded_model/reduce_scatter.py @@ -0,0 +1,200 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import os +from typing import Callable, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup + +# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved. +if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0": + enable_nccl_base_collectives = False +else: + enable_nccl_base_collectives = True + + +class Bucket: + + def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup): + self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device) + self.group = group + self.offset = 0 + self.callbacks: List[Callable] = [] + self.output_shard = torch.zeros_like(self.buffer[0]) + + def flush(self) -> None: + """Flush content of the bucket.""" + if self.offset == 0: + assert len(self.callbacks) == 0 + return + # reduce-scatter bucket + if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives: + dist._reduce_scatter_base(self.output_shard[:self.offset], + self.buffer[:, :self.offset].contiguous(), + group=self.group) + else: + dist.reduce_scatter(self.output_shard[:self.offset], + list(self.buffer[:, :self.offset].unbind(0)), + group=self.group) + # execute post-reduction callbacks + for callback_fn in self.callbacks: + callback_fn() + # reuse input bucket but allocate a fresh output shard + self.buffer[:, :self.offset].zero_() + self.offset = 0 + self.callbacks.clear() + self.output_shard = torch.zeros_like(self.buffer[0]) + + def alloc(self) -> None: + """Setup the buffers if they are not allocated. + + Using ``setup`` and ``teardown``, we can ensure that the bucket + buffers are only allocated during the backward pass, hence saving more + memory to other parts of the training process, such as the forward pass + for activation memory. + """ + for tensor in [self.buffer, self.output_shard]: + if tensor.storage().size() == 0: + tensor.storage().resize_(tensor.size().numel()) + + def free(self) -> None: + """Tear down the bucket by freeing the memory""" + assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown" + for tensor in [self.buffer, self.output_shard]: + tensor.storage().resize_(0) + + def append(self, tensor_list: List[Tensor], callback_fn: Callable): + # copy data from input_list into bucket + tensor_size = tensor_list[0].numel() + stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size) + offset = self.offset + self.buffer[:, offset:offset + tensor_size].copy_(stacked_input) + self.offset += tensor_size + + # callback will be given the reduced result + if callback_fn is not None: + result_view = self.output_shard[offset:offset + tensor_size].view_as(tensor_list[0]) + self.callbacks.append(functools.partial(callback_fn, result_view)) + + +class ReduceScatterBucketer: + """ + Helper for bucketing multiple reduce-scatter operations on small tensors + into larger reduce-scatter ops to improve communication efficiency. + + Usage:: + + bucketer = ReduceScatterBucketer() + bucketer.reduce_scatter_async( + small_tensors, callback_fn=lambda result: print("small") + ) + bucketer.reduce_scatter_async( + big_tensors, callback_fn=lambda result: print("big") + ) + bucketer.reduce_scatter_async( + more_small_tensors, callback_fn=lambda result: print("small2") + ) + bucketer.flush() # callbacks only guaranteed to be called after flush() + # Example output (note that it is out of order, due to bucketing): + # big + # small + # small2 + + Args: + bucket_size_mb (int, Optional): bucket size for communicating. Buckets + are sub-divided based on world_size. Values <= 0 disable bucketing. + """ + + def __init__(self, bucket_size_mb: int = 25): + self.bucket_size_mb = bucket_size_mb + self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {} + + @torch.no_grad() + def reduce_scatter_async( + self, + input_list: List[Tensor], + group: ProcessGroup, + callback_fn: Optional[Callable] = None, + ) -> None: + """ + Reduce-scatter a list of tensors asynchronously, so smaller reductions + can be bucketed together. The given callback (``callback_fn``) will be + called with the reduced result at some later time. Call ``flush()`` to + force all queued ops and callbacks to be executed. + + Note that large inputs will be reduced immediately, and this function + may also flush the relevant bucket to make room for ``input_list``. + + Args: + input_list (List[Tensor]): list of tensors to reduce-scatter. List + should contain ``group.size()`` tensors and each tensor should + have identical shape, dtype and device. + group (ProcessGroup): process group for reduction + callback_fn (Callable, Optional): callback function to call after + the reduction executes. Function will be called with a single + argument corresponding to the reduced result. + """ + world_size = group.size() + + assert (len(input_list) == world_size + ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})" + + first_input = input_list[0] + first_input_size = first_input.numel() + + bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size) + if first_input_size > bucket_shard_size: + # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors) + # input is too big to fit in the bucket, reduce-scatter directly + output = torch.zeros_like(input_list[0]) + if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives: + input_flattened = torch.cat(input_list) + dist._reduce_scatter_base(output, input_flattened, group=group) + else: + # fallback + dist.reduce_scatter(output, input_list, group=group) + if callback_fn is not None: + callback_fn(output) + return + + bucket = self._get_bucket(first_input, group) + if first_input_size > bucket.buffer.size(1) - bucket.offset: + # not enough space remaining in bucket, flush it now + bucket.flush() + bucket.append(input_list, callback_fn) + + @torch.no_grad() + def flush(self) -> None: + """Reduce-scatter any partial buckets.""" + for bucket in self.buckets.values(): + bucket.flush() + + @torch.no_grad() + def free(self) -> None: + """Free buffers from all buckets.""" + for bucket in self.buckets.values(): + bucket.free() + + @functools.lru_cache() + def _get_shard_size(self, element_size: int, num_shards: int) -> int: + if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. + return 0 + MB = 1024 * 1024 + bucket_size = self.bucket_size_mb * MB / element_size + return int(bucket_size // num_shards) + + def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket: + key = (tensor.dtype, tensor.device, group) + if key not in self.buckets: + # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size) + world_size = group.size() + shard_size = self._get_shard_size(tensor.element_size(), world_size) + self.buckets[key] = Bucket(shard_size, tensor.dtype, tensor.device, group) + self.buckets[key].alloc() + return self.buckets[key] diff --git a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..b3a83b7418250216c4b275c17fd88cb4ec779641 --- /dev/null +++ b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py @@ -0,0 +1,572 @@ +# this code is inspired by the DeepSpeed library and implemented with our own design from scratch +import functools +import itertools +from collections import OrderedDict +from copy import deepcopy +from typing import Any, Iterator, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.utils import disposable, get_current_device +from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.zero.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector +from colossalai.zero.legacy.gemini.ophooks import register_ophooks_recursively +from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr +from colossalai.zero.legacy.gemini.stateful_tensor import TensorState +from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.zero.legacy.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory +from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_move_to_cpu +from colossalai.zero.legacy.shard_utils import BaseShardStrategy +from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBucketer + +from ._utils import ( + cast_float_arguments, + cast_tensor_to_fp16, + cast_tensor_to_fp32, + chunk_and_pad, + free_storage, + get_gradient_predivide_factor, +) +from .zero_hook import ZeroHook + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + + +class ShardedModelV2(nn.Module): + """ + A wrapper for the PyTorch module shards the model parameters among multiple GPU memory. + Only `1/#nproc` of parameters, gradients are stored in local CUDA memory, so forward and backward + passes can be executed with limited CUDA memory budget. + + Note: + You must use ``ShardedModelV2`` with ``ShardedOptimizerV2``. + Note: + Make sure you don't use gradient accumulation and your optimizer can work with fp16 gradient and fp32 parameter, + if you enable ``reuse_fp16_shard``. + + Args: + module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`. + shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior. + process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None. + reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group. + Generally, it should be `None`, and it's the same as `process_group`. Defaults to None. + reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25. + fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False. + tensor_placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'. + If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used. + If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used. + If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well. + Note that 'auto' policy can only work well when no other processes use CUDA during your training. + Defaults to 'cuda'. + gradient_predivide_factor (Optional[float], optional): Gradient is divived by this value before reduce-scatter. Defaults to 1.0. + reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad. + Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation. + In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad). + We find that PyTorch's optimizers don't support mixed precision, + so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False. + """ + + def __init__(self, + module: nn.Module, + shard_strategy: BaseShardStrategy, + process_group: Optional[ProcessGroup] = None, + reduce_scatter_process_group: Optional[ProcessGroup] = None, + reduce_scatter_bucket_size_mb: int = 25, + fp32_reduce_scatter: bool = False, + tensor_placement_policy: str = 'cuda', + gradient_predivide_factor: Optional[float] = 1.0, + reuse_fp16_shard: bool = False, + *args, + **kwargs): + assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.' + super().__init__() + self.logger = get_dist_logger() + + # We force users to use ZeroInitContext + for submodule in module.modules(): + sharded_cnt = 0 + unshard_cnt = 0 + for param in submodule.parameters(recurse=False): + assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.' + if param.colo_attr.param_is_sharded: + sharded_cnt += 1 + else: + unshard_cnt += 1 + assert (not sharded_cnt) or (not unshard_cnt), 'nn.Module can not both have shard param and unshard param' + submodule.param_is_sharded = (sharded_cnt > 0) + + self.sharded_params = [] + self.unshard_params = [] + for param in module.parameters(): + if param.colo_attr.param_is_sharded: + self.sharded_params.append(param) + else: + self.unshard_params.append(param) + + self.module = module + self.process_group = process_group or gpc.get_group(ParallelMode.DATA) + self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group + self.world_size = dist.get_world_size(self.process_group) + self.rank = dist.get_rank(self.process_group) + self.shard_strategy = shard_strategy + + self._use_memory_tracer = tensor_placement_policy == 'auto' + if self._use_memory_tracer: + self._memstats_collector = MemStatsCollector() + self._start_collect_memstats = disposable(self._memstats_collector.start_collection) + self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection) + else: + self._memstats_collector = None + self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create( + tensor_placement_policy)(mem_stats_collector=self._memstats_collector) + + if 'warmup_non_model_data_ratio' in kwargs: + if tensor_placement_policy != 'auto': + self.logger.warning('setting warmup_non_model_data_ratio is useless if not use auto placement') + else: + ratio = kwargs['warmup_non_model_data_ratio'] + self._tensor_placement_policy._warmup_non_model_data_ratio = ratio + self.logger.info(f'setting warmup_non_model_data_ratio as {ratio} for auto placement') + + self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy) + param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')] + self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list) + + # Register hooks + self._ophook_list = [ + ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group) + ] + register_ophooks_recursively(self.module, self._ophook_list) + self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters())) + self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook) + + self.fp32_reduce_scatter = fp32_reduce_scatter + self._cpu_offload: bool = tensor_placement_policy != 'cuda' + for param in module.parameters(): + # Init `offload_grad` + param.colo_attr.offload_grad = self._cpu_offload + + # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem + # So we use 1.0 as the default gradient_predivide_factor + # However, if you set gradient_predivide_factor to None, we will set + # gradient_predivide_factor to a value >= 1.0 automatically + self.gradient_predivide_factor: float = gradient_predivide_factor if \ + gradient_predivide_factor is not None else \ + get_gradient_predivide_factor(self.world_size) + self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor + + self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() + self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb) + self._require_backward_grad_sync: bool = True + + self._cuda_margin_space = 0 + self.reuse_fp16_shard = reuse_fp16_shard + + # record whether gradients have inf or nan + self.overflow_counter = 0 + + def adjust_stateful_tensor_layout(self) -> None: + self._stateful_tensor_mgr.adjust_layout() + + @property + def use_memory_tracer(self): + return self._use_memory_tracer + + @property + def cuda_margin_space(self): + return self._cuda_margin_space + + @property + def cpu_offload(self): + return self._cpu_offload + + def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None: + """ + dummy memory tracer collected information to a file. + try: + # forward: model(inputs) + # backward: optimizer.backward() + except Exception as e: + model.dump_memory_stats() + exit(0) + """ + if self._use_memory_tracer: + self.logger.error(f'dump memort tracer collected information to a {filename}', ranks=[0]) + if gpc.get_global_rank() == 0: + with open(filename, 'w+') as f: + f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n') + f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n') + f.write('CUDA model data (GB)\n') + f.write('\n') + f.write('CUDA non model data (GB)\n') + f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda'))) + f.write('CPU non model data (GB)\n') + f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu'))) + f.write('\n') + + def _pre_forward_operations(self, *args): + # the operation will affect the memory tracer behavior in ZeroHook + if self._memstats_collector: + self._start_collect_memstats() + + for p in self.module.parameters(): + if hasattr(p, 'colo_attr'): + p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) + + self._stateful_tensor_mgr.start_iter() + + def _post_forward_operations(self): + for p in self.module.parameters(): + if hasattr(p, 'colo_attr'): + p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) + + def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: + self._pre_forward_operations(*args) + args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs) + outputs = self.module(*args, **kwargs) + self._post_forward_operations() + return outputs + + def backward(self, loss): + loss.backward() + self._post_backward_operations() + for ophook in self._ophook_list: + ophook.post_iter() + + def backward_by_grad(self, tensor, grad): + torch.autograd.backward(tensors=tensor, grad_tensors=grad) + self._post_backward_operations() + for ophook in self._ophook_list: + ophook.post_iter() + + def _update_memstats(self): + if self._memstats_collector: + self._finish_collect_memstats() + # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used. + # the way to calculate margin space is based on the assumption that + # model data is fixed in cuda during training. + # cuda margin space can be used to store OS. + self._cuda_margin_space = colo_device_memory_capacity( + get_current_device()) - self._memstats_collector._memstats.max_overall_cuda + + @torch.no_grad() + def _post_backward_operations(self) -> None: + """ + The method includes operations required to be processed after backward + 1. update memory tracer. + 2. flush the gradient in buckets. Reducing partial gradients in each process. + 3. shard tensors not dealed in the zero hook + 4. move sharded param grad payload to param.grad + """ + # 1. update memory tracer. + self._update_memstats() + + # 2. flush the gradient in buckets. Reducing partial gradients in each process. + if self._require_backward_grad_sync: + # Flush any unreduced buckets in the post_backward stream. + with torch.cuda.stream(self.comm_stream): + self.reducer.flush() + torch.cuda.current_stream().wait_stream(self.comm_stream) + self.reducer.free() + + # 3. shard tensors not dealed in the zero hook + tensor_list = [] + for p in self.sharded_params: + if not p.colo_attr.param_is_sharded: + tensor_list.append(p.colo_attr.sharded_data_tensor) + p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD) + p.colo_attr.set_data_none() + self.shard_strategy.shard(tensor_list, self.process_group) + + # 4. set all parameters' grad to None + for p in self.module.parameters(): + if not p.requires_grad: + continue + # Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass. + # NOTE() (no-sync)/sync pass: (not conduct)/conduct gradient all reducing between process group. + # If _require_backward_grad_sync is True, + # p.grad remains the accumulated unsharded gradient from prior no-sync passes. + # We also allows to interleave no-sync pass with sync passes, if desired. + if not self._require_backward_grad_sync: + continue + + p.grad = None + + @torch.no_grad() + def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]: + """ + At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the + full gradient for the local batch. The reduce-scatter op will save + a single shard of the summed gradient across all + GPUs to param.colo_attr.grad. This shard will align with the current GPU rank. For example:: + + before reduce_scatter: + param.grad (GPU #0): [1, 2, 3, 4] + param.grad (GPU #1): [5, 6, 7, 8] + + after reduce_scatter: + param.grad (GPU #0): [6, 8] # 1+5, 2+6 + param.grad (GPU #1): [10, 12] # 3+7, 4+8 + + The local GPU's ``optim.step`` is responsible for updating a single + shard of params, also corresponding to the current GPU's rank. This + alignment is created by `param.colo_attr.grad`, which ensures that + the local optimizer only sees the relevant parameter shard. + """ + if grad is None: + return + assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients' + if not self._require_backward_grad_sync: + return + # used to cheat Pytorch, since we can't return None + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + # As torch didn't allow modifying grad in hook, we make a copy + grad = grad.clone() + if param.colo_attr.is_replicated: + self._reduce_scatter_handler(param, grad) + else: + self._save_grad(param, grad) + return empty_grad + + def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None: + self.comm_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.comm_stream): + if self.fp32_reduce_scatter: + grad.data = grad.data.to(param.dtype) + if self.gradient_predivide_factor > 1.0: + # Average grad by world_size for consistency with PyTorch DDP. + grad.data.div_(self.gradient_predivide_factor) + if self.world_size > 1: + grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size()) + self.reducer.reduce_scatter_async(grad_chunks, + group=self.reduce_scatter_process_group, + callback_fn=functools.partial(self._reduce_scatter_callback, param)) + else: + self._reduce_scatter_callback(param, grad) + torch.cuda.current_stream().wait_stream(self.comm_stream) + + def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None: + assert isinstance(reduced_grad, + torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}" + reduced_grad.data = reduced_grad.data.contiguous().view(-1) + if self.gradient_postdivide_factor > 1: + # Average grad by world_size for consistency with PyTorch DDP. + reduced_grad.data.div_(self.gradient_postdivide_factor) + self._save_grad(param, reduced_grad) + + # FIXME(ver217): refactor the below line when impl eviction policy + def _save_grad(self, param: Parameter, grad: torch.Tensor): + + # record whether we have overflow + self.overflow_counter += torch.isinf(grad).any().item() + self.overflow_counter += torch.isnan(grad).any().item() + + # move gradient to cpu + if param.colo_attr.offload_grad: + colo_model_data_move_to_cpu(grad) + + if self.reuse_fp16_shard: + # make parameters point to gradient + + assert param.colo_attr.saved_grad.is_null( + ), 'Gradien accumulation is not supported when reuse_fp16_shard=True' + + param.colo_attr.grad_payload_reset(grad.data) + # release the memory of param + # we set a false None for parameter's payload + # so we can get parameter's device and dtype later in optimizer + param.colo_attr.data_payload_reset(torch.empty(0, device=grad.device, dtype=grad.dtype)) + + if param.colo_attr.is_replicated: + param.colo_attr.sharded_data_tensor.is_sharded = True + else: + + fp32_grad = cast_tensor_to_fp32(grad) + + if param.colo_attr.saved_grad.is_null(): + param.colo_attr.grad_payload_reset(fp32_grad) + else: + param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload)) + + # keep saved_grad in HOLD state + param.colo_attr.saved_grad.trans_state(TensorState.HOLD) + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + return self.module.parameters(recurse=recurse) + + def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: + return self.module.named_parameters(prefix, recurse) + + def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]': + return self._colo_state_dict(destination, + prefix, + keep_vars, + shard_strategy=self.shard_strategy, + state_dict_func=nn.Module.state_dict, + module_to_load=self.module, + sharded_params=self.sharded_params, + process_group=self.process_group) + + def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None: + for name, p in self.named_parameters(): + if name in state_dict: + p.colo_attr.data_payload_reset(state_dict[name].to(dtype=p.colo_attr.data_payload.dtype, + device=p.colo_attr.data_payload.device)) + # Force re-shard + p.colo_attr.sharded_data_tensor.is_sharded = False + self.shard_strategy.shard([p.colo_attr.sharded_data_tensor]) + elif strict: + raise RuntimeError(f'Missing key in state_dict: {name}') + + def _colo_state_dict(self, + destination=None, + prefix='', + keep_vars=False, + shard_strategy: Optional[BaseShardStrategy] = None, + state_dict_func=None, + module_to_load=None, + sharded_params=[], + process_group=None) -> 'OrderedDict[str, torch.Tensor]': + if len(sharded_params) == 0: + for param in self.parameters(): + if param.colo_attr.param_is_sharded: + sharded_params.append(param) + if shard_strategy is not None: + shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group) + for p in sharded_params: + p.data = p.colo_attr.data_payload + module_to_load = module_to_load or self + gathered_state_dict = state_dict_func(module_to_load, destination, prefix, keep_vars) + gathered_state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in gathered_state_dict.items()} + if shard_strategy is not None: + shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group) + for p in sharded_params: + p.colo_attr.set_data_none() + return gathered_state_dict + + def _colo_load_from_state_dict(self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + shard_strategy=None): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + shard_strategy (Optional[BaseShardStrategy], optional): A shard strategy to manage shard behavior. Defaults to None. + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in local_state.items(): + key = prefix + name + if key in state_dict: + input_param = state_dict[key] + if hasattr(param, 'colo_attr'): + param.colo_attr.data_payload_reset( + input_param.to(dtype=param.colo_attr.data_payload.dtype, + device=param.colo_attr.data_payload.device)) + if shard_strategy is not None: + # Force re-shard + param.colo_attr.sharded_data_tensor.is_sharded = False + shard_strategy.shard([param.colo_attr.sharded_data_tensor]) + else: + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' + 'the shape in current model is {}.'.format( + key, input_param.shape, param.shape)) + continue + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append('While copying the parameter named "{}", ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}, ' + 'an exception occurred : {}.'.format(key, param.size(), input_param.size(), + ex.args)) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", nn.Module.set_extra_state) is not nn.Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix):] + input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) + + def __getitem__(self, idx: int): + assert isinstance(self.module, nn.ModuleList) + return self.module[idx] + + def __len__(self): + assert isinstance(self.module, nn.ModuleList) + return len(self.module) + + def __iter__(self): + assert isinstance(self.module, nn.ModuleList) + return iter(self.module) diff --git a/colossalai/zero/legacy/sharded_model/utils.py b/colossalai/zero/legacy/sharded_model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..08806e78ea3bf245e80d220e30f35029227bf144 --- /dev/null +++ b/colossalai/zero/legacy/sharded_model/utils.py @@ -0,0 +1,20 @@ +import copy + +import torch + +from colossalai.zero.legacy.sharded_model import ShardedModelV2 + + +def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module): + """ + copy param of the ShardedModelV2 to other_model. + Note the other_model has to be the same as self. + """ + for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()): + assert hasattr(zero_param, 'colo_attr') + shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded + if shard_flag: + sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor]) + param.data = copy.deepcopy(zero_param.colo_attr.data_payload) + if shard_flag: + sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor]) diff --git a/colossalai/zero/legacy/sharded_model/zero_hook.py b/colossalai/zero/legacy/sharded_model/zero_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..50f4bdfc775d77a938d51442e2966b720a519b94 --- /dev/null +++ b/colossalai/zero/legacy/sharded_model/zero_hook.py @@ -0,0 +1,118 @@ +from typing import Optional + +import torch +import torch.distributed as dist + +from colossalai.logging import get_dist_logger +from colossalai.registry import OPHOOKS +from colossalai.utils import get_current_device +from colossalai.zero.gemini.memory_tracer import MemStatsCollector +from colossalai.zero.legacy.gemini.ophooks import BaseOpHook +from colossalai.zero.legacy.gemini.stateful_tensor import TensorState +from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.zero.legacy.shard_utils import BaseShardStrategy + + +@OPHOOKS.register_module +class ZeroHook(BaseOpHook): + """ + A hook to process sharded param for ZeRO method. + Warning: this class has been deprecated after version 0.1.12 + """ + + def __init__(self, + shard_strategy: BaseShardStrategy, + memstarts_collector: Optional[MemStatsCollector] = None, + stateful_tensor_mgr: Optional[StatefulTensorMgr] = None, + process_group: Optional[dist.ProcessGroup] = None): + super().__init__() + self.logger = get_dist_logger("ZeROHook") + self.shard_strategy = shard_strategy + self.process_group = process_group + + # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU + self.computing_device = get_current_device() + + self._memstarts_collector = memstarts_collector + self._stateful_tensor_mgr = stateful_tensor_mgr + + def gather_parameters(self, module: torch.nn.Module): + # gather sharded parameters + if module.param_is_sharded: + tensor_list = [] + for param in module.parameters(recurse=False): + assert hasattr(param, 'colo_attr') + tensor_list.append(param.colo_attr.sharded_data_tensor) + self.shard_strategy.gather(tensor_list, self.process_group) + + def shard_parameters(self, module: torch.nn.Module): + # shard gathered parameters + if module.param_is_sharded: + tensor_list = [] + for param in module.parameters(recurse=False): + assert hasattr(param, 'colo_attr') + tensor_list.append(param.colo_attr.sharded_data_tensor) + self.shard_strategy.shard(tensor_list, self.process_group) + + def adjust_module_data(self, module: torch.nn.Module): + # record overall data statistics + if self._memstarts_collector: + self._memstarts_collector.sample_overall_data() + + for param in module.parameters(recurse=False): + param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE) + + # adjust stateful tensor to get enough CUDA memory + self._stateful_tensor_mgr.adjust_layout() + + # record model data statistics + if self._memstarts_collector: + self._memstarts_collector.record_model_data_volume() + + def pre_fwd_exec(self, module: torch.nn.Module, *args): + self.adjust_module_data(module) + self.gather_parameters(module) + for param in module.parameters(recurse=False): + param.data = param.colo_attr.data_payload + assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA" + + def post_fwd_exec(self, module: torch.nn.Module, *args): + + # change tensor state to HOLD_AFTER_FWD + for param in module.parameters(recurse=False): + param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD) + + self.shard_parameters(module) + + # remove torch payload + for param in module.parameters(recurse=False): + param.colo_attr.set_data_none() + + def pre_bwd_exec(self, module: torch.nn.Module, input, output): + self.adjust_module_data(module) + self.gather_parameters(module) + for param in module.parameters(recurse=False): + param.data = param.colo_attr.data_payload + assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA" + + def post_bwd_exec(self, module: torch.nn.Module, input): + + # change tensor state to HOLD_AFTER_BWD + for param in module.parameters(recurse=False): + param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD) + + self.shard_parameters(module) + + # remove torch payload + for param in module.parameters(recurse=False): + param.colo_attr.set_data_none() + + def pre_iter(self): + pass + + def post_iter(self): + if self._stateful_tensor_mgr: + self.logger.debug( + f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB, get layout info time: {self._stateful_tensor_mgr._layout_time}, evict cpu time: {self._stateful_tensor_mgr._evict_time}", + ranks=[0]) + self._stateful_tensor_mgr.finish_iter() diff --git a/colossalai/zero/legacy/sharded_optim/__init__.py b/colossalai/zero/legacy/sharded_optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b71a70aeffa44d9f4f539bf7ee4d497a55ebd3d5 --- /dev/null +++ b/colossalai/zero/legacy/sharded_optim/__init__.py @@ -0,0 +1,3 @@ +from .sharded_optim_v2 import ShardedOptimizerV2 + +__all__ = ['ShardedOptimizerV2'] diff --git a/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py b/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..be60209af434ea80da71f76299217f0c9fd4340e --- /dev/null +++ b/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py @@ -0,0 +1,388 @@ +# this code is inspired by the DeepSpeed library and implemented with our own design from scratch +from enum import Enum +from os import stat +from typing import Dict, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter +from torch.optim import Optimizer + +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState +from colossalai.zero.legacy.gemini.tensor_placement_policy import AutoTensorPlacementPolicy +from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp32 + + +class OptimState(Enum): + SCALED = 1 + UNSCALED = 2 + + +class ShardedOptimizerV2(ColossalaiOptimizer): + """A wrapper for optimizer. ``ShardedOptimizerV2`` and ``ShardedModelV2`` implement Zero Redundancy Optimizer (ZeRO). + + By default the ZeRO optimizer stage 3 offload Optimizer States on CPU. + + We apply the Device-aware Operator Placement technique for OS placement from the following paper. + + `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_ + + GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory, + which is detected by a runtime memory tracer. + + We place as many OS chunks in the margin space as possible. + + The size of margin space can be controlled by ``gpu_margin_mem_ratio``. + If it is set as ``0.0``, it is the same as classical ZeRO optimizer. + + Note: + You must use ``ShardedOptimizerV2`` with ``ShardedModelV2``. + + Note: + Make sure you set ``tensor_placement_policy`` in ``ShardedModelV2`` to `"auto"`, + if you set ``gpu_margin_mem_ratio > 0``. + + Args: + sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the + shard strategy provided by sharded model to shard param fp32 tensors. + optimizer (Optimizer): An Optimizer instance. + gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) + which will be used when using hybrid CPU optimizer. + This argument is meaningless when `tensor_placement_policy` of `ShardedModelV2` is not "auto". + Defaults to 0.0. + initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. + min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. + growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. + backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. + growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000. + hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2. + max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32. + dp_process_group (Optional[ProcessGroup], optional): data parallel process group. Defaults to None. + mp_process_group (Optional[ProcessGroup], optional): model parallel process group. Defaults to None. + + .. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management: + https://arxiv.org/abs/2108.05818 + """ + + def __init__(self, + sharded_model: ShardedModelV2, + optimizer: Optimizer, + gpu_margin_mem_ratio: float = 0.0, + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + dp_process_group: Optional[ProcessGroup] = None, + mp_process_group: Optional[ProcessGroup] = None, + verbose: bool = False) -> None: + assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel' + assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.' + + super().__init__(optimizer) + self.shard_strategy = sharded_model.shard_strategy + self.model: ShardedModelV2 = sharded_model + + self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) + assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0' + # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid + # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors, + # and it must set `num_fp32_shards_per_param` correctly + self._should_move_fp32_shards_h2d: bool = sharded_model.cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr( + optimizer, 'num_fp32_shards_per_param', 0) >= 2 + self.device = sharded_model._tensor_placement_policy.device or torch.device('cpu') + self.optim_state: OptimState = OptimState.UNSCALED + self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA) + self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL) + # Grad scaler + self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device()) + self._logger = get_dist_logger("ShardedOptimizerV2") + self._verbose = verbose + + # Store fp32 param shards + self._register_master_weight() + if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy, + AutoTensorPlacementPolicy): + self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"', + ranks=[0]) + + if self._verbose: + self._logger.debug( + f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0]) + + self._use_memory_tracer = self.model.use_memory_tracer + + @property + def loss_scale(self): + return self.grad_scaler.scale.item() + + def get_memory_usage(self) -> Tuple[int, int]: + """ Get the memory usage of the optimizer. Including master_params (param fp32), + momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``) + + Returns: + Tuple[int, int]: cuda/cpu memory usage in Byte. + """ + cuda_use = 0 + cpu_use = 0 + + def update_mem_use(t): + nonlocal cuda_use + nonlocal cpu_use + t_cuda_use, t_cpu_use = colo_tensor_mem_usage(t) + cuda_use += t_cuda_use + cpu_use += t_cpu_use + + for _, p_fp32 in self.master_params.items(): + update_mem_use(p_fp32) + for group in self.optim.param_groups: + for p in group['params']: + state = self.optim.state[p] + for k, v in state.items(): + update_mem_use(v) + + return cuda_use, cpu_use + + def zero_grad(self, *args, **kwargs): + self._zero_grad() + + def backward(self, loss: Tensor) -> None: + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + self.model.backward(loss) + + def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None: + # This function is called except the last stage of pipeline parallel + # It receives the scaled grad from the previous rank + # No need to scale the grad again + # Need to unscale when optimizing + self.optim_state = OptimState.SCALED + self.model.backward_by_grad(tensor, grad) + + def clip_grad_norm(self, model: nn.Module, max_norm: float): + if self.optim_state == OptimState.SCALED: + self._prepare_grads() + self._unscale_grads() + return super().clip_grad_norm(model, max_norm) + + def step(self, *args, **kwargs): + + # unscale grads if scaled + if self.optim_state == OptimState.SCALED: + self._prepare_grads() + self._unscale_grads() + + self._maybe_move_fp32_shards() + found_inf = self._check_overflow() + self.grad_scaler.update(found_inf) + + if found_inf: + self._logger.warning('found inf during ShardedOptimV2 step') + self._zero_grad(recover_data=True) + return + + self._point_param_fp16_to_master_param() + + if self._verbose: + gpu_mem, cpu_mem = self.get_memory_usage() + self._logger.debug( + f"Before step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!", + ranks=[0]) + ret = self.optim.step(*args, **kwargs) + + if self._verbose: + gpu_mem, cpu_mem = self.get_memory_usage() + self._logger.debug( + f"After step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!", + ranks=[0]) + + self._copy_master_model_to_model_fp16() + return ret + + def _check_overflow(self): + # clear previous overflow record + self._found_overflow.fill_(self.model.overflow_counter) + + # all-reduce across dp group + dist.all_reduce(self._found_overflow, group=self.dp_process_group) + + # all-reduce over model parallel group + dist.all_reduce(self._found_overflow, group=self.mp_process_group) + + return self._found_overflow.item() > 0 + + def _unscale_grads(self): + assert self.optim_state == OptimState.SCALED + for group in self.optim.param_groups: + for p in group['params']: + if p.grad is not None: + p.grad.data.div_(self.loss_scale) + self.optim_state = OptimState.UNSCALED + + def _zero_grad(self, recover_data: bool = False): + """zero grad and maybe recover fp16 params + When `reuse_fp16_shard` is enabled, + p.colo_attr.sharded_data_tensor stores grad here. + We have to recover them from fp32 params. + + Args: + recover_data (bool, optional): Whether to recover fp16 param from fp32 param. Defaults to False. + """ + # We must set grad to None + # Because grad here is sharded + # But next backward pass will create a full grad first + # Which leads to wrong accumulation + self.optim.zero_grad(set_to_none=True) + for group in self.optim.param_groups: + for p in group['params']: + # p.colo_attr.sharded_data_tensor stores grad now + # we have to recover fp16 param + reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0) + if recover_data and reuse_fp16_shard: + self._copy_master_param_to_param_fp16(p) + else: + # release saved gradient + p.colo_attr.saved_grad.set_null() + self.model.overflow_counter = 0 # set overflow counter to zero + + def sync_grad(self): + pass + + def _register_master_weight(self): + self.master_params: Dict[Parameter, StatefulTensor] = {} + for group in self.optim.param_groups: + for p in group['params']: + assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam' + shard_flag = not p.colo_attr.sharded_data_tensor.is_sharded and p.colo_attr.is_replicated + if shard_flag: + # we always shard replicated parameters + self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group) + self.master_params[p] = StatefulTensor(cast_tensor_to_fp32(p.colo_attr.data_payload.to(self.device))) + if shard_flag: + # In this branch, there's no need to shard param + # So we gather here + self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group) + + def _maybe_move_fp32_shards(self): + if self._should_move_fp32_shards_h2d: + self._should_move_fp32_shards_h2d = False + available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio + fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param + fp32_shards_used_cuda_margin_mem = 0 + for group in self.optim.param_groups: + for p in group['params']: + if p.colo_attr.saved_grad.is_null(): + continue + shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size() + if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem: + colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device()) + colo_model_data_tensor_move_inline(p.colo_attr.saved_grad, torch.cuda.current_device()) + p.colo_attr.offload_grad = False + fp32_shards_used_cuda_margin_mem += shard_mem + state = self.optim.state[p] + for k, v in state.items(): + if isinstance(v, Tensor): + state[k] = v.cuda() + + def _prepare_grads(self): + for group in self.optim.param_groups: + for p in group['params']: + if p.colo_attr.saved_grad.is_null(): + continue + p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE) + # If reuse_fp16_shard, grad fp16 which wasn't be offloaded may be evicted to CPU + if not p.colo_attr.offload_grad: + colo_model_data_tensor_move_inline(p.colo_attr.saved_grad, torch.cuda.current_device()) + # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful information + # If we change p.grad directly + # it may raise error because of different shape/dtype/device of p.data and p.grad + # We just set p.data = p.colo_attr.saved_grad.payload here + p.data = p.colo_attr.grad_payload + p.grad = p.colo_attr.grad_payload + # Set p.data to empty tensor, in case of memory leaking + p.colo_attr.set_data_none() + + def _point_param_fp16_to_master_param(self): + # assign master param pointers to p.data. + # We will not trigger data copy here. + for group in self.optim.param_groups: + for p in group['params']: + self.master_params[p].trans_state(TensorState.COMPUTE) + p.data = self.master_params[p].payload + # Now p.data is sharded + # So optimizer states are sharded naturally + + def _copy_master_model_to_model_fp16(self): + # Copy master param data (fp32) to payload of colo_attr (fp16) + # TODO() improve efficiency by gathering tensors into a chunk and transferring + # a chunk. + for group in self.optim.param_groups: + for p in group['params']: + self._copy_master_param_to_param_fp16(p) + + def _copy_master_param_to_param_fp16(self, p): + # flush gradient + if p.colo_attr.sharded_data_tensor.payload_size == 0: + # here reuse_fp16_shard is True + # in order to use copy below, we should give sharded data tensor a payload + p.colo_attr.sharded_data_tensor.payload_relay(p.colo_attr.saved_grad) + else: + p.colo_attr.saved_grad.set_null() + + p.data = self.master_params[p].payload + + # we need to allocate new memory for keep_not_shard parameters + # in order to use copy, otherwise, the sizes of tensor is not compatible + if p.colo_attr.data_payload.numel() != p.data.numel(): + p.colo_attr.data_payload_reset( + torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device)) + + # TODO() optimize this line CPU (fp32) -> GPU (fp16) + p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach()) + p.colo_attr.set_data_none() + + if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated: + # We gather full fp16 param here + p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True + self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group) + + self.master_params[p].trans_state(TensorState.HOLD) + + def state_dict(self): + optim_state_dict = super().state_dict() + scaler_state_dict = self.grad_scaler.state_dict() + optim_state_dict['scaler'] = scaler_state_dict + return optim_state_dict + + def load_state_dict(self, *args, **kwargs): + if 'scaler' not in args[0]: + self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0]) + else: + scaler_state_dict = args[0].pop('scaler') + self.grad_scaler.load_state_dict(scaler_state_dict) + super().load_state_dict(*args, **kwargs) + for group in self.optim.param_groups: + for p in group['params']: + state = self.optim.state[p] + for k, v in state.items(): + if isinstance(v, Tensor): + state[k] = v.to(dtype=self.master_params[p].dtype, device=self.master_params[p].device) diff --git a/colossalai/zero/legacy/sharded_param/__init__.py b/colossalai/zero/legacy/sharded_param/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..47e2ce2fa0e015c978b726c1fa77bb66fca91c42 --- /dev/null +++ b/colossalai/zero/legacy/sharded_param/__init__.py @@ -0,0 +1,4 @@ +from .sharded_param import ShardedParamV2 +from .sharded_tensor import ShardedTensor + +__all__ = ['ShardedTensor', 'ShardedParamV2'] diff --git a/colossalai/zero/legacy/sharded_param/sharded_param.py b/colossalai/zero/legacy/sharded_param/sharded_param.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcc4b62104ab23b0f86683bea1bf34294c07a70 --- /dev/null +++ b/colossalai/zero/legacy/sharded_param/sharded_param.py @@ -0,0 +1,110 @@ +from typing import List, Optional, Tuple + +import torch + +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState +from colossalai.zero.legacy.gemini.tensor_utils import colo_tensor_mem_usage + +from .sharded_tensor import ShardedTensor + +EMPTY_TENSOR_DICT = {} + + +def get_empty_tensor(device: torch.device, dtype: torch.dtype): + key = (device, dtype) + if key not in EMPTY_TENSOR_DICT: + EMPTY_TENSOR_DICT[key] = torch.empty(0, dtype=dtype, device=device) + + return EMPTY_TENSOR_DICT[key] + + +class ShardedParamV2(object): + + def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None: + self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data) + self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE) + # This attribute must be initialized in ShardedModel + self.offload_grad: bool = False + + # make sure the shared param is the only owner of payload + # The param.data maybe used to init the other part of the model. + # For example: File "resnet.py", line 190, in __init__ + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + # So we can not empty the .data at this time + self.param = param + if set_data_none: + self.set_data_none() + + def get_payload_tensors(self) -> List[StatefulTensor]: + """returns stateful tensors kept by this class. + """ + return [self._sharded_data_tensor] + + def set_data_none(self): + self.param.data = get_empty_tensor(self.sharded_data_tensor.device, self.sharded_data_tensor.dtype) + + def set_grad_none(self): + self.saved_grad.set_null() + + @property + def sharded_data_tensor(self): + return self._sharded_data_tensor + + @property + def data_payload(self): + assert not self.sharded_data_tensor.is_null() + return self.sharded_data_tensor.payload + + @property + def grad_payload(self): + assert not self.saved_grad.is_null() + return self.saved_grad.payload + + @property + def param_is_sharded(self): + return self.sharded_data_tensor.is_sharded + + def data_payload_reset(self, tensor: torch.Tensor): + assert type(tensor) is torch.Tensor + assert tensor.requires_grad is False + self.sharded_data_tensor.payload_reset(tensor) + + def grad_payload_reset(self, tensor: torch.Tensor): + assert type(tensor) is torch.Tensor + assert tensor.requires_grad is False + self.saved_grad.payload_reset(tensor) + + def get_memory_usage(self) -> Tuple[int, int]: + """ + get the memory usage of the param, including data and grad + Returns: + Tuple[int, int]: cuda mem usage in Byte, cpu memory usage in Byte + """ + cuda_mem_use, cpu_mem_use = 0, 0 + + def _update_mem_use(t: Optional[torch.Tensor]): + if t is None: + return + assert isinstance(t, torch.Tensor) + nonlocal cuda_mem_use + nonlocal cpu_mem_use + t_cuda, t_cpu = colo_tensor_mem_usage(t) + cuda_mem_use += t_cuda + cpu_mem_use += t_cpu + + address_set = set() + _update_mem_use(self.data_payload) + address_set.add(self.data_payload.data_ptr()) + + if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set: + _update_mem_use(self.grad_payload) + address_set.add(self.saved_grad.data_ptr()) + + if self.param.data is not None and self.param.data.data_ptr() not in address_set: + _update_mem_use(self.param.data) + address_set.add(self.param.data.data_ptr()) + + if self.param.grad is not None and self.param.grad.data_ptr() not in address_set: + _update_mem_use(self.param.grad) + + return cuda_mem_use, cpu_mem_use diff --git a/colossalai/zero/legacy/sharded_param/sharded_tensor.py b/colossalai/zero/legacy/sharded_param/sharded_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..af60312600f22554450218ce82d4890e93f926fe --- /dev/null +++ b/colossalai/zero/legacy/sharded_param/sharded_tensor.py @@ -0,0 +1,40 @@ +import torch + +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState + + +class ShardedTensor(StatefulTensor): + + def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None: + r""" + A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance. + """ + assert tensor.requires_grad is False + super().__init__(tensor, state) + + # kept the shape, numel and dtype of the init tensor. + self._origin_shape = tensor.shape + self._origin_numel = tensor.numel() + self._origin_dtype = tensor.dtype + self._is_sharded = False + + @property + def dtype(self) -> torch.dtype: + assert self._payload.dtype == self._origin_dtype + return self._payload.dtype + + @property + def origin_numel(self) -> int: + return self._origin_numel + + @property + def origin_shape(self) -> int: + return self._origin_shape + + @property + def is_sharded(self): + return self._is_sharded + + @is_sharded.setter + def is_sharded(self, flag: bool): + self._is_sharded = flag diff --git a/colossalai/zero/low_level/__init__.py b/colossalai/zero/low_level/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae3c1de3a5bc68bc178240729f51fde0ba8a6222 --- /dev/null +++ b/colossalai/zero/low_level/__init__.py @@ -0,0 +1,3 @@ +from .low_level_optim import LowLevelZeroOptimizer + +__all__ = ['LowLevelZeroOptimizer'] diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..afc98e7a7f54a7b4125eb989790c521fad3c86a7 --- /dev/null +++ b/colossalai/zero/low_level/_utils.py @@ -0,0 +1,273 @@ +import math +from typing import Optional + +import torch +import torch.distributed as dist +from torch import inf +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from colossalai.tensor import ColoParameter +from colossalai.utils import is_model_parallel_parameter + + +def flatten(input_): + return _flatten_dense_tensors(input_) + + +def unflatten(flat, tensors): + return _unflatten_dense_tensors(flat, tensors) + + +def count_numel(tensor_list): + res = 0 + for tensor in tensor_list: + res += tensor.numel() + return res + + +def calculate_padding(numel, unit_size): + remainder = numel % unit_size + return unit_size - remainder if remainder else remainder + + +def shuffle_by_round_robin(tensor_list, num_partitions): + partitions = dict() + + for tensor_idx, tensor in enumerate(tensor_list): + partition_to_go = tensor_idx % num_partitions + if partition_to_go not in partitions: + partitions[partition_to_go] = [] + partitions[partition_to_go].append(dict(tensor=tensor, index=tensor_idx)) + + partitions_count = len(partitions) + new_tensor_list = [] + tensor_index_mapping = dict() + + for partition_id in range(partitions_count): + partition_tensors = partitions[partition_id] + for item in partition_tensors: + tensor_index_mapping[item['index']] = len(new_tensor_list) + new_tensor_list.append(item['tensor']) + + return new_tensor_list, tensor_index_mapping + + +# create a flat tensor aligned at the alignment boundary +def flatten_dense_tensors_with_padding(tensor_list, unit_size): + num_elements = count_numel(tensor_list) + padding = calculate_padding(num_elements, unit_size=unit_size) + + if padding > 0: + pad_tensor = torch.zeros(padding, device=tensor_list[0].device, dtype=tensor_list[0].dtype) + padded_tensor_list = tensor_list + [pad_tensor] + else: + padded_tensor_list = tensor_list + + return flatten(padded_tensor_list) + + +def is_nccl_aligned(tensor): + return tensor.data_ptr() % 4 == 0 + + +def get_grad_accumulate_object(tensor): + """ + Return the AccumulateGrad of the input tensor + """ + + # grad_fn reference: + # https://discuss.pytorch.org/t/in-the-grad-fn-i-find-a-next-functions-but-i-dont-understand-the-meaning-of-the-attribute/24463 + # expand_as reference: https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch.Tensor.expand + # + # `next_functions` will return the backward graph where + # the first element is the AccumulateGrad of the leaf nodes. + # we want to get the AccumulateGrad of the input tensor instead of the leaf + # node in the whole computation graph. + # Therefore, we call expand_as to create a dummy graph + # where tensor_tmp and tensor indeed point to the same object. + # You can check this by print(tensor.data_ptr() == tensor_tmp.data_ptr()) + tensor_tmp = tensor.expand_as(tensor) + grad_acc_obj = tensor_tmp.grad_fn.next_functions[0][0] + return grad_acc_obj + + +def split_by_dtype(tensor_list): + """ + Splits a list of PyTorch tensors into sublists based on their data type. + + :param tensor_list: A list of PyTorch tensors. + :type tensor_list: list[torch.Tensor] + :return: A list of sublists, where each sublist contains tensors of a specific data type. + :rtype: list[list[torch.Tensor]] + """ + dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"] + buckets = [] + for _, dtype in enumerate(dtypes): + bucket = [t for t in tensor_list if t.type() == dtype] + if bucket: + buckets.append(bucket) + return buckets + + +def reduce_tensor_dp_group(tensor: torch.Tensor, + dtype: Optional[torch.dtype] = None, + dst_local_rank: Optional[int] = None, + dst_global_rank: Optional[int] = None, + group: Optional[dist.ProcessGroup] = None): + """ + Reduce the tensor in the data parallel process group + + :param tensor: A tensor object to reduce/all-reduce + :param dtype: The data type used in communication + :param dst_rank: The source rank for reduce. If dst_rank is None, + :param parallel_mode: Communication parallel mode + all-reduce will be used instead of reduce. Default is None. + + :type tensor: torch.Tensor + :type dtype: torch.dtype, optional + :type dst_rank: int, optional + :type pg: ProcessGroup, optional + """ + # use the original dtype + if dtype is None: + dtype = tensor.dtype + + # cast the data to specified dtype for reduce/all-reduce + if tensor.dtype != dtype: + tensor_to_reduce = tensor.to(dtype) + else: + tensor_to_reduce = tensor + + world_size = dist.get_world_size(group=group) + tensor_to_reduce.div_(world_size) + + # if rank is None, all reduce will be used + # else, reduce is used + use_all_reduce = dst_local_rank is None + + if use_all_reduce: + dist.all_reduce(tensor_to_reduce, group=group) + else: + dist.reduce(tensor=tensor_to_reduce, dst=dst_global_rank, group=group) + + # recover the original dtype + if tensor.dtype != dtype and tensor is not tensor_to_reduce: + local_rank = dist.get_rank(group=group) + if use_all_reduce or dst_local_rank == local_rank: + tensor.copy_(tensor_to_reduce) + + return tensor + + +def has_inf_or_nan(tensor): + try: + # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as tensor + # (which is true for some recent version of pytorch). + tensor_sum = float(tensor.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # tensor_sum = float(tensor.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum: + return True + return False + + +def release_param_grad(tensor_list): + for tensor in tensor_list: + tensor.grad = None + + +def calculate_global_norm_from_list(norm_list): + """ Compute total from a list of norms + """ + total_norm = 0.0 + for norm in norm_list: + total_norm += norm**2.0 + return math.sqrt(total_norm) + + +def compute_norm(gradients, params, dp_group, mp_group, norm_type=2): + """Clips gradient norm of an iterable of parameters. + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + Returns: + Total norm of the parameters (viewed as a single vector). + """ + + if mp_group is None: + mp_rank = 0 + else: + mp_rank = dist.get_rank(mp_group) + + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(g.data.abs().max() for g in gradients) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group) + + # Take max across all GPUs. + if mp_group is not None: + dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX) + total_norm = total_norm_cuda[0].item() + else: + total_norm = 0.0 + # if dist.get_rank() == 0: + # logger.info(f"Total Norm beginning {total_norm}") + + for g, p in zip(gradients, params): + # Pipeline parallelism may replicate parameters. Avoid multi-counting. + tp_param_flag = False + if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()): + tp_param_flag = True + if tp_param_flag or mp_rank == 0: + param_norm = g.data.double().norm(2) + total_norm += param_norm.item()**2 + + # Sum across all model parallel GPUs. + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group) + + if mp_group is not None: + dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=mp_group) + + total_norm = total_norm_cuda[0].item()**(1. / norm_type) + + if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + +def sync_param(flat_tensor, tensor_list): + """ + Synchronize the flattened tensor and unflattened tensor list. When + a list of tensor are flattened with `torch._utils._unflatten_dense_tensors`, + a new tensor is created. Thus, the flat tensor and original tensor list do not + share the same memory space. This function will update the tensor list so that + they point to the same value. + + :param flat_tensor: A flat tensor obtained by calling `torch._utils._unflatten_dense_tensors` on a tensor lsit + :param tensor_list: A list of tensors corresponding to the flattened tensor + :type flat_tensor: torch.Tensor + :type tensor_list: List[torch.Tensor] + """ + updated_params = unflatten(flat_tensor, tensor_list) + + # update the tensor data + for p, q in zip(tensor_list, updated_params): + p.data = q.data diff --git a/colossalai/zero/low_level/bookkeeping/__init__.py b/colossalai/zero/low_level/bookkeeping/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7bcacfabfded39972babff0536cb75b0c2c65506 --- /dev/null +++ b/colossalai/zero/low_level/bookkeeping/__init__.py @@ -0,0 +1,6 @@ +from .bucket_store import BucketStore +from .gradient_store import GradientStore +from .parameter_store import ParameterStore +from .tensor_bucket import TensorBucket + +__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket'] diff --git a/colossalai/zero/low_level/bookkeeping/base_store.py b/colossalai/zero/low_level/bookkeeping/base_store.py new file mode 100644 index 0000000000000000000000000000000000000000..2ebd122464f4ef46976d7b7e7e45323aea322b2d --- /dev/null +++ b/colossalai/zero/low_level/bookkeeping/base_store.py @@ -0,0 +1,17 @@ +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +class BaseStore: + + def __init__(self, torch_pg: ProcessGroup): + self._world_size = dist.get_world_size(group=torch_pg) + self._local_rank = dist.get_rank(group=torch_pg) + + @property + def world_size(self): + return self._world_size + + @property + def local_rank(self): + return self._local_rank diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py new file mode 100644 index 0000000000000000000000000000000000000000..ec322a78bf81a21d9b5e08b6c5bd9c44ca06f3f0 --- /dev/null +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -0,0 +1,41 @@ +from torch.distributed import ProcessGroup + +from .base_store import BaseStore + + +class BucketStore(BaseStore): + + def __init__(self, torch_pg: ProcessGroup): + super().__init__(torch_pg) + self._params = dict() + self._num_elements_in_bucket = dict() + + self.reset() + + def num_elements_in_bucket(self, reduce_rank: int = None): + return self._num_elements_in_bucket[reduce_rank] + + def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None): + self._num_elements_in_bucket[reduce_rank] += num_elements + + def add_param(self, tensor, reduce_rank: int = None): + self._params[reduce_rank].append(tensor) + + def reset(self): + keys = [None] + list(range(self._world_size)) + self._params = {rank: [] for rank in keys} + self._num_elements_in_bucket = {rank: 0 for rank in keys} + + def reset_by_rank(self, reduce_rank=None): + self._params[reduce_rank] = [] + self._num_elements_in_bucket[reduce_rank] = 0 + + def get_grad(self, reduce_rank: int = None): + param_list = self.get_param(reduce_rank) + for param in param_list: + # the param must have grad for reduction + assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced' + return [param.grad for param in param_list] + + def get_param(self, reduce_rank: int = None): + return self._params[reduce_rank] diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py new file mode 100644 index 0000000000000000000000000000000000000000..942d7186e55f53f97211491025c759a7bed18cf9 --- /dev/null +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -0,0 +1,88 @@ +from typing import List + +from torch import Tensor + +from .base_store import BaseStore + + +class GradientStore(BaseStore): + + def __init__(self, *args): + super().__init__(*args) + # bookkeeping data structures + self._averaged_gradients = dict() + + # for backward reduction hooks + self._grad_acc_objs = [] + + def append_accumulate_grad_object(self, obj): + """ + Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not + be attached successfully. + + :param obj: An object of :class:`AccumulateGrad` class + :type obj: :class:`AccumulateGrad` + """ + + self._grad_acc_objs.append(obj) + + def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]: + """ + Return average gradients of a parameter group + + :param group_id: The index of parameter group + :type group_id: int + + :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter. + :rtype: List[torch.Tensor] + """ + if group_id not in self._averaged_gradients: + self._averaged_gradients[group_id] = [] + + return self._averaged_gradients[group_id] + + def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None: + """ + Append an average gradient to the list of averaged gradients of a parameter group + + :param group_id: The index of a parameter group + :param tensor: A :class:`torch.Tensor` object + :type group_id: int + :type tensor: torch.Tensor + + """ + + if group_id in self._averaged_gradients: + self._averaged_gradients[group_id].append(tensor) + else: + self._averaged_gradients[group_id] = [tensor] + + def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None: + """ + Add an average gradient to the list of averaged gradients of a parameter group + + :param group_id: The index of a parameter group + :param tensor_idx: The index of a tensor in the list of averaged gradients + :param tensor: A :class:`torch.Tensor` object + :type group_id: int + :type tensor_idx: int + :type tensor: torch.Tensor + + """ + self._averaged_gradients[group_id][tensor_idx].add_(tensor) + + def reset_average_gradients_by_group(self, group_id: int) -> None: + """ + Reset the bookkeeping data structure for averaged gradients to an empty list + + :param group_id: The index of a parameter group + :type group_id: int + """ + + self._averaged_gradients[group_id] = [] + + def reset_all_average_gradients(self) -> None: + """ + Reset the bookkeeping data structure for averaged gradients to an empty list + """ + self._averaged_gradients = dict() diff --git a/colossalai/zero/low_level/bookkeeping/parameter_store.py b/colossalai/zero/low_level/bookkeeping/parameter_store.py new file mode 100644 index 0000000000000000000000000000000000000000..1f3ba7cbc3bc75c029072a11648628e7b866aa2d --- /dev/null +++ b/colossalai/zero/low_level/bookkeeping/parameter_store.py @@ -0,0 +1,97 @@ +from typing import List + +from torch import Tensor +from torch.distributed import ProcessGroup + +from .base_store import BaseStore + + +class ParameterStore(BaseStore): + + def __init__(self, torch_pg: ProcessGroup): + super().__init__(torch_pg) + # param partitioning data structures + self._param_to_rank = dict() + self._rank_group_id_to_param_list = dict() + self._rank_group_id_to_flat_param = dict() + + # param reduction data structures + self._is_param_reduced = dict() + self._reduced_param = [] + + def set_param_to_rank(self, tensor: Tensor, rank: int) -> None: + """ + Set the mapping between parameter to rank, each parameter should be owned by a rank. + + :param tensor: A :class:`torch.Tensor` object + :type tensor: torch.Tensor + :param rank: The rank of which the process is responsible for updating the parameter + :type rank: int + """ + + self._param_to_rank[tensor] = rank + + def get_param_rank(self, tensor: Tensor) -> int: + """ + Gives the rank which the parameter belongs to + + :param tensor: A :class:`torch.Tensor` object + :type tensor: torch.Tensor + """ + return self._param_to_rank[tensor] + + def belongs_to_current_rank(self, tensor) -> bool: + """ + Check whether a parameter is supposed to be updated by the process of the current rank + + :param tensor: A :class:`torch.Tensor` object + :type tensor: torch.Tensor + + :return: True if the parameter should be updated by the current rank. Otherwise false. + :rtype: bool + """ + + tensor_rank = self._param_to_rank[tensor] + return tensor_rank == self._local_rank + + def add_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None: + if rank not in self._rank_group_id_to_param_list: + self._rank_group_id_to_param_list[rank] = dict() + + if group_id not in self._rank_group_id_to_param_list[rank]: + self._rank_group_id_to_param_list[rank][group_id] = [] + + self._rank_group_id_to_param_list[rank][group_id].extend(tensor_list) + + def get_params_by_rank_group(self, rank, group_id) -> List[Tensor]: + return self._rank_group_id_to_param_list[rank][group_id] + + def add_flat_param_by_rank_group(self, rank, group_id, tensor) -> None: + if rank not in self._rank_group_id_to_flat_param: + self._rank_group_id_to_flat_param[rank] = dict() + + self._rank_group_id_to_flat_param[rank][group_id] = tensor + + def get_flat_param_by_rank_group(self, rank, group_id) -> Tensor: + return self._rank_group_id_to_flat_param[rank][group_id] + + def is_param_reduced(self, tensor): + return self._is_param_reduced[tensor] + + def set_param_reduction_state(self, tensor, state): + self._is_param_reduced[tensor] = state + + def get_param_reduction_states(self): + return self._is_param_reduced + + def reset_previous_reduced_params(self): + self._reduced_param = [] + + def add_previous_reduced_param(self, tensor): + self._reduced_param.append(tensor) + + def clear_grads_of_previous_reduced_params(self): + if len(self._reduced_param) > 0: + for param in self._reduced_param: + param.grad = None + self.reset_previous_reduced_params() diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py new file mode 100644 index 0000000000000000000000000000000000000000..b32816a046cd6a156196e84e957e070c4401d555 --- /dev/null +++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py @@ -0,0 +1,53 @@ +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + + +class TensorBucket: + + def __init__(self, size): + self._max_size = size + self._current_size = 0 + self._bucket = [] + + @property + def max_size(self): + return self._max_size + + @property + def current_size(self): + return self._current_size + + def is_full_or_oversized(self): + return self._current_size >= self._max_size + + def is_empty(self): + return len(self._bucket) == 0 + + def add_to_bucket(self, tensor, allow_oversize=False): + tensor_size = tensor.numel() + + if not allow_oversize and self.will_exceed_max_size(tensor_size): + msg = f"The param bucket max size {self._max_size} is exceeded" \ + + f"by tensor (size {tensor_size})" + raise RuntimeError(msg) + + self._bucket.append(tensor) + self._current_size += tensor_size + + def will_exceed_max_size(self, tensor_size): + expected_size = self._current_size + tensor_size + return expected_size > self._max_size + + def get_bucket(self): + return self._bucket + + def empty(self): + self._bucket = [] + self._size = 0 + + def flatten(self): + return _flatten_dense_tensors(self._bucket) + + def unflatten_and_copy(self, flat_tensor): + unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket) + for old, new in zip(self._bucket, unflattened_tensor_list): + old.copy_(new) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py new file mode 100644 index 0000000000000000000000000000000000000000..3e7661ecab769768c5ba0be00f9462d57b696f39 --- /dev/null +++ b/colossalai/zero/low_level/low_level_optim.py @@ -0,0 +1,595 @@ +# this code is inspired by the DeepSpeed library and implemented with our own design from scratch +from functools import partial +from typing import Optional + +import torch +import torch.distributed as dist +from torch.optim import Optimizer + +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.tensor import ColoParameter, ProcessGroup +from colossalai.utils.cuda import get_current_device + +from ._utils import ( + calculate_global_norm_from_list, + compute_norm, + flatten, + has_inf_or_nan, + reduce_tensor_dp_group, + release_param_grad, + split_by_dtype, + sync_param, +) +from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket + + +class LowLevelZeroOptimizer(ColossalaiOptimizer): + """Optimizer used for ZeRO-1 and ZeRO-2. + """ + + def __init__( + self, + optimizer: Optimizer, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2., + backoff_factor: float = .5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = False, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + forced_dtype: Optional[torch.dtype] = None): + + # TODO: add support for + # 1. fp16 master weights + # 2. contiguous gradients + # 3. cpu offload + # 4. support when some parameters requires_grad = False + # 5. support layer drop + super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) + self._dtype = self.optim.param_groups[0]['params'][0].dtype + self._logger = get_dist_logger() + self._verbose = verbose + + # stage 2 + self._partition_grads = partition_grad + + self._cpu_offload = cpu_offload + + colo_pg = self._search_colo_process_group() + if isinstance(colo_pg, ProcessGroup): + self._local_rank = colo_pg.dp_local_rank() + self._world_size = colo_pg.dp_world_size() + self._dp_global_ranks = colo_pg.get_ranks_in_dp() + self._dp_torch_group = colo_pg.dp_process_group() + self._mp_torch_group = None + if colo_pg.tp_world_size() > 1: + self._mp_torch_group = colo_pg.tp_process_group() + elif colo_pg is None: + dp_parallel_mode = ParallelMode.DATA + mp_parallel_mode = ParallelMode.MODEL + + self._dp_parallel_mode = dp_parallel_mode + self._mp_parallel_mode = mp_parallel_mode + self._local_rank = gpc.get_local_rank(dp_parallel_mode) + self._world_size = gpc.get_world_size(dp_parallel_mode) + self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode) + self._dp_torch_group = gpc.get_group(dp_parallel_mode) + self._mp_torch_group = None + if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1: + self._mp_torch_group = gpc.get_group(mp_parallel_mode) + else: + raise NotImplementedError + + # working and master params for mixed precision training + self._working_param_groups = dict() + self._master_flat_param_groups_of_current_rank = dict() + + # communication params + self._overlap_communication = overlap_communication + self._reduce_bucket_size = reduce_bucket_size + self._communication_dtype = communication_dtype + + # gradient scaler + self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + verbose=verbose) + self._found_overflow = torch.FloatTensor([0]).to(get_current_device()) + + # gradient clipping + self._clip_grad_norm = clip_grad_norm + + if forced_dtype: + for group in self.optim.param_groups: + group_params = group['params'] + for param in group_params: + param.data = param.data.to(forced_dtype) + self._dtype = forced_dtype + + # check argument conflict + self._sanity_checks() + + # ParameterStore will manage the tensor buffers used for zero + # it will not manage the tensors used by mixed precision training + self._param_store = ParameterStore(self._dp_torch_group) + self._grad_store = GradientStore(self._dp_torch_group) + self._bucket_store = BucketStore(self._dp_torch_group) + + # iterate over the param group in the optimizer + # partition these param groups for data parallel training + # and add buffers to parameter store for future access + for group_id, param_group in enumerate(self.optim.param_groups): + group_params = list() + for param in param_group['params']: + if param.requires_grad: + group_params.append(param) + + # add the working params to working_param_groups for bookkeeping + self._working_param_groups[group_id] = group_params + + # assign parameters to ranks + # the params in the list are sorted + params_per_rank = self._partition_param_list(group_params) + + # store the mapping between param to rank + # each param should belong to only one rank + for rank, params in enumerate(params_per_rank): + self._param_store.add_param_list_by_rank_group(rank, group_id, params) + for param in params: + self._param_store.set_param_to_rank(param, rank) + + # move to cpu to make room to create the flat tensor + # move_tensor(params, device='cpu') + for param in group_params: + param.data = param.data.cpu() + + # flatten the reordered tensors + for rank in range(self._world_size): + tensor_list = self._param_store.get_params_by_rank_group(rank, group_id) + with torch.no_grad(): + flat_tensor = flatten(tensor_list) + flat_tensor = flat_tensor.data.cuda() + self._param_store.add_flat_param_by_rank_group(rank, group_id, flat_tensor) + + # sync parameters + for rank in range(self._world_size): + flat_tensor = self._param_store.get_flat_param_by_rank_group(rank, group_id) + tensor_list = self._param_store.get_params_by_rank_group(rank, group_id) + sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list) + + # create a copy of fp32 master weights of the parameters for which this rank is responsible + working_flat_current_rank = self._param_store.get_flat_param_by_rank_group(self._local_rank, group_id) + master_flat_current_rank = working_flat_current_rank.float() + device = 'cpu' if self._cpu_offload else get_current_device() + master_flat_current_rank = master_flat_current_rank.to(device) + master_flat_current_rank.requires_grad = True + self._master_flat_param_groups_of_current_rank[group_id] = master_flat_current_rank + + # need to replace the params in the `params` field in the optimizer + # so that when the optimizer calls step(), it only updates the tensors + # managed by this data parallel rank + param_group['params'] = [master_flat_current_rank] + + # set reduction state + for param in self._working_param_groups[group_id]: + self._param_store.set_param_reduction_state(param, False) + + # intialize communication stream for + # communication-compuation overlapping + if self._overlap_communication: + self._comm_stream = torch.cuda.Stream() + + # reduction hook is only used if overlapping communication + # or stage 2 is used + # if it is stage 1 without overlapping, no hook will be attached + if self._overlap_communication or self._partition_grads: + self._attach_reduction_hook() + + @property + def dtype(self): + return self._dtype + + @property + def loss_scale(self): + return self.grad_scaler.scale + + @property + def num_param_groups(self): + return len(self._working_param_groups) + + def _sanity_checks(self): + assert torch.cuda.is_available(), 'CUDA is required' + for param_group in self.optim.param_groups: + group_params = param_group['params'] + for param in group_params: + assert param.dtype == self._dtype, \ + f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" + + def _search_colo_process_group(self): + colo_flag = False + colo_pg = None + for param_group in self.optim.param_groups: + group_params = param_group['params'] + for param in group_params: + if isinstance(param, ColoParameter): + colo_flag = True + if colo_pg is None: + colo_pg = param.get_process_group() + else: + assert colo_pg == param.get_process_group(), "All parameters should be in a same process group" + elif colo_flag: + raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.") + return colo_pg + + def _partition_param_list(self, param_list): + params_per_rank = [[] for _ in range(self._world_size)] + numel_per_rank = [0 for _ in range(self._world_size)] + + # partititon the parameters in a greedy fashion + sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True) + for param in sorted_params: + # allocate this parameter to the rank with + # the smallest numel for load balancing purpose + rank_to_go = numel_per_rank.index(min(numel_per_rank)) + params_per_rank[rank_to_go].append(param) + numel_per_rank[rank_to_go] += param.numel() + + if self._verbose: + self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0]) + return params_per_rank + + ########################### + # Backward Reduction Hook # + ########################### + + def _grad_handler(self, param, grad, reduce_rank): + self._add_to_reduction_bucket(param, reduce_rank) + return grad + + def _attach_reduction_hook(self): + # we iterate over the working params + # on each param, we register a hook to its AccumulateGrad object + for group_id in range(self.num_param_groups): + param_group = self._working_param_groups[group_id] + for param in param_group: + if param.requires_grad: + # determines the reduction destionation rank + # this is only valid for stage 2 + # dst_rank = None means using all-reduce + # else using reduce + if self._partition_grads: + reduce_rank = self._param_store.get_param_rank(param) + else: + reduce_rank = None + + param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank)) + + def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank): + if self._overlap_communication: + torch.cuda.synchronize() + self._param_store.clear_grads_of_previous_reduced_params() + stream = self._comm_stream + else: + stream = torch.cuda.current_stream() + + with torch.cuda.stream(stream): + flat = bucket.flatten() + reduce_global_rank = None + if reduce_rank is not None: + reduce_global_rank = self._dp_global_ranks[reduce_rank] + reduced_flat = reduce_tensor_dp_group(tensor=flat, + dtype=self._communication_dtype, + dst_local_rank=reduce_rank, + dst_global_rank=reduce_global_rank, + group=self._dp_torch_group) + + # update the reduced tensor + if reduce_rank is None or reduce_rank == self._local_rank: + bucket.unflatten_and_copy(reduced_flat) + + def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank): + param_bucket = TensorBucket(size=bucket_size) + + for tensor in tensor_list: + param_bucket.add_to_bucket(tensor, allow_oversize=True) + + if param_bucket.is_full_or_oversized(): + self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank) + param_bucket.empty() + + if not param_bucket.is_empty(): + self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank) + + def _reduce_grads(self, reduce_rank, grads, bucket_size): + grad_buckets_by_dtype = split_by_dtype(grads) + + for tensor_list in grad_buckets_by_dtype: + self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list, + bucket_size=bucket_size, + reduce_rank=reduce_rank) + + ####################### + # Reduction Functions # + ####################### + + def _run_reduction(self, reduce_rank=None): + # reduce grads + self._reduce_grads(reduce_rank=reduce_rank, + grads=self._bucket_store.get_grad(reduce_rank=reduce_rank), + bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank)) + + # use communication stream if overlapping + # communication with computation + if self._overlap_communication: + stream = self._comm_stream + else: + stream = torch.cuda.current_stream() + + with torch.cuda.stream(stream): + params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank) + + for param in params_in_bucket: + # the is_param_reduced flag should be False showing that + # this param is not reduced before calling self._reduce_grads_by_rank + is_param_reduced = self._param_store.is_param_reduced(param) + + if is_param_reduced: + msg = f'Parameter of size ({param.size()}) has been reduced, ' + \ + 'duplicate reduction will lead to arithmetic incorrectness' + raise RuntimeError(msg) + + # update the flag + self._param_store.set_param_reduction_state(param, True) + + # if partition grads = True + # we do not keep the gradient after reduction + if self._partition_grads and not self._param_store.belongs_to_current_rank(param): + if self._overlap_communication: + # we need to keep this gradient for now as reduction may + # be completed yet since it is using a different cuda stream + self._param_store.add_previous_reduced_param(param) + else: + param.grad = None + + self._bucket_store.reset_by_rank(reduce_rank) + + def _add_to_reduction_bucket(self, param, reduce_rank=None): + param_size = param.numel() + + # check if the bucket is full + # if full, will reduce the grads already in the bucket + # after reduction, the bucket will be empty + if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size: + self._run_reduction(reduce_rank) + + # the param must not be reduced to ensure correctness + is_param_reduced = self._param_store.is_param_reduced(param) + if is_param_reduced: + msg = f'Parameter of size ({param.size()}) has already been reduced, ' \ + + 'duplicate reduction will lead to arithmetic incorrectness' + raise RuntimeError(msg) + + self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank) + self._bucket_store.add_param(param, reduce_rank) + + ################################ + # torch.optim.Optimizer methods + ################################ + + def backward(self, loss, retain_graph=False, sync_grad=True): + loss = self.loss_scale * loss + loss.backward(retain_graph=retain_graph) + + # finish gradient reduction + if not self._partition_grads: + self._reduce_grad_stage1() + else: + # TODO: support async comm in reduce + self._reduce_grad_stage2() + + # clear reduced grads + if self._overlap_communication: + torch.cuda.synchronize() + self._param_store.clear_grads_of_previous_reduced_params() + + # gradient synchronization + if sync_grad: + self._sync_grad() + + def zero_grad(self, set_to_none=True): + """ + Set parameter gradients to zero. If set_to_none = True, gradient + will be set to None to save memory. + + :param set_to_none: Whether set the gradient to None. Default value is True. + :type set_to_none: bool + """ + for _, param_group in self._working_param_groups.items(): + for param in param_group: + if set_to_none: + param.grad = None + else: + if param.grad is not None: + param.grad.detach() + param.grad.zero_() + + #################### + # Update Parameter # + #################### + + def step(self, closure=None): + assert closure is None, 'closure is not supported by step()' + + # check for overflow + found_inf = self._check_overflow() + self.grad_scaler.update(found_inf) + + # update loss scale if overflow occurs + if found_inf: + self._grad_store.reset_all_average_gradients() + if self._verbose: + self._logger.info(f'Found overflow. Skip step') + self.zero_grad() + return + + # copy the grad of working param to master param + single_grad_partition_groups = [] + norm_groups = [] + + for group_id in range(self.num_param_groups): + # compute norm + norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id), + params=self._param_store.get_params_by_rank_group(group_id=group_id, + rank=self._local_rank), + dp_group=self._dp_torch_group, + mp_group=self._mp_torch_group) + norm_groups.append(norm_group) + + # create flat gradient for the flat fp32 master params + working_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id) + flat_working_avg_grads = flatten(working_avg_grads) + + dtype = self._master_flat_param_groups_of_current_rank[group_id].dtype + flat_master_avg_grads = flat_working_avg_grads.to(dtype) + + param_shape = self._master_flat_param_groups_of_current_rank[group_id].shape + assert param_shape == flat_master_avg_grads.shape, \ + f'fp32 param and grad have different shape {param_shape} vs {flat_master_avg_grads.shape}' + + single_grad_partition_groups.append(flat_master_avg_grads) + device = self._master_flat_param_groups_of_current_rank[group_id].device + self._master_flat_param_groups_of_current_rank[group_id].grad = flat_master_avg_grads.to(device) + self._grad_store.reset_average_gradients_by_group(group_id) + + # unscale and clip grads + global_norm = calculate_global_norm_from_list(norm_list=norm_groups) + self._unscale_and_clip_grads(single_grad_partition_groups, global_norm) + + # update the parameters + self.optim.step() + # release the master grad + release_param_grad(self._master_flat_param_groups_of_current_rank.values()) + + # update working partition updated by the current rank + for group_id in range(len(self._working_param_groups)): + working_param = self._param_store.get_flat_param_by_rank_group(rank=self._local_rank, group_id=group_id) + master_param = self._master_flat_param_groups_of_current_rank[group_id] + working_param.data.copy_(master_param) + + # broadcast the updated model weights + handles = [] + for group_id in range(self.num_param_groups): + for index in range(self._world_size): + rank = self._dp_global_ranks[index] + working_param = self._param_store.get_flat_param_by_rank_group(rank=index, group_id=group_id) + handle = dist.broadcast(working_param, src=rank, group=self._dp_torch_group, async_op=True) + handles.append(handle) + + for handle in handles: + handle.wait() + + ############################# + # Mixed Precision Utilities # + ############################# + + def _check_overflow(self): + # clear previous overflow record + self._found_overflow.fill_(0.0) + + # check for overflow + for group_id in range(len(self._working_param_groups)): + for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id): + if avg_grad is not None and has_inf_or_nan(avg_grad): + self._found_overflow.fill_(1.0) + break + + # all-reduce across dp group + dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group) + + # all-reduce over model parallel group + if self._mp_torch_group: + dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group) + + if self._found_overflow.item() > 0: + return True + else: + return False + + def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): + # compute combined scale factor for this group + combined_scale = self.loss_scale + + if self._clip_grad_norm > 0.: + # norm is in fact norm*scale + clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm + if clip > 1: + combined_scale = clip * self.loss_scale + + for grad in grad_groups_flat: + grad.data.mul_(1. / combined_scale) + + ############################ + # Gradient Synchronization # + ############################ + + def _sync_grad(self): + # update param already reduced flag + reduction_states = self._param_store.get_param_reduction_states() + for tensor, _ in reduction_states.items(): + reduction_states[tensor] = False + + # accumulate gradient + for group_id in range(self.num_param_groups): + param_group = self._param_store.get_params_by_rank_group(self._local_rank, group_id) + + avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id) + + param_idx = 0 + for param in param_group: + if param.grad is not None: + if len(avg_gradients_group) == param_idx: + self._grad_store.append_average_gradient_by_group(group_id, param.grad) + else: + self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad) + param_idx += 1 + + # the gradients needed are stored in the avg_gradients buffer + # thus, can clear this + self.zero_grad() + + def _reduce_grad_stage1(self): + # if not overlapping communication (no reduction hook is attached) + # we need to manually reduce these gradients + if not self._overlap_communication: + for group_id in range(len(self._working_param_groups)): + param_group = self._working_param_groups[group_id] + for param in param_group: + if param.grad is not None: + self._add_to_reduction_bucket(param) + + # we need to reduce the gradients + # left in the communication bucket + self._run_reduction() + + def _reduce_grad_stage2(self): + # when partition_grads is True, reduction hooks + # are attached in the __init__ function, so we + # only need to reduce the gradients + # left in the communication bucket + for reduce_rank in range(self._world_size): + self._run_reduction(reduce_rank) diff --git a/colossalai/zero/wrapper.py b/colossalai/zero/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..3e48f49fa305fb59e45e66775e930ea36eb1c7db --- /dev/null +++ b/colossalai/zero/wrapper.py @@ -0,0 +1,114 @@ +from copy import copy +from typing import Dict, Optional + +import torch +import torch.nn as nn + +from .gemini import GeminiDDP + + +def zero_model_wrapper(model: nn.Module, + zero_stage: int = 1, + gemini_config: Optional[Dict] = None, + verbose: bool = False): + """This wrapper function is used to wrap your training model for ZeRO DDP. + + Example: + + >>> with ColoInitContext(): + >>> my_model = Bert() + >>> my_optim = SGD(my_model.parameters(), lr = 1e-3) + >>> zero_model = zero_model_wrapper(my_model, zero_stage=1) + >>> zero_optim = zero_optim_wrapper(zero_model, my_optim) + + Args: + model (nn.Module): The model used in ZeRO DDP. + zero_stage (int, optional): The stage of ZeRO DDP. You can find more information in ZeRO's paper. + https://arxiv.org/abs/1910.02054 + gemini_config (dict, optional): The configuration dictionary of `GeminiDDP`. `GeminiDDP` is enabled + when the stage is set to 3. You can set the arguments of `GeminiDDP` in the gemini_config. + Here is an example where we set the device of the model, the placement policy of Gemini, and the + size of hidden dimension to help Gemini find out a unified chunk size. + + Example: + + >>> config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto') + >>> model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict) + """ + assert zero_stage in [1, 2, 3], "The stage of ZeRO should be 1, 2 or 3" + + if gemini_config is None: + gemini_config = dict() + + if zero_stage in [1, 2]: + wrapped_model = model + else: + wrapped_model = GeminiDDP(model, **gemini_config, verbose=verbose) + + setattr(wrapped_model, "_colo_zero_stage", zero_stage) + + return wrapped_model + + +def zero_optim_wrapper(model: nn.Module, + optimizer: torch.optim.Optimizer, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + optim_config: Optional[Dict] = None, + verbose: bool = False): + """This wrapper function is used to wrap your training optimizer for ZeRO DDP. + + Args: + model (nn.Module): Your model wrapped by `zero_model_wrapper` + optimizer (torch.optim.Optimizer): Your initialized optimizer + initial_scale (float, optional): initial_scale used by DynamicGradScaler. + min_scale (float, optional): min_scale used by DynamicGradScaler. + growth_factor (float, optional): growth_factor used by DynamicGradScaler. + backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. + growth_interval (float, optional): growth_interval used by DynamicGradScaler. + hysteresis (float, optional): hysteresis used by DynamicGradScaler. + max_scale (int, optional): max_scale used by DynamicGradScaler. + max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do + clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm. + norm_type (float, optional): norm_type used for `clip_grad_norm`. + optim_config (dict, optional): The configuration used for the ZeRO optimizer. + Example: + + >>> zero2_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True) + >>> optim = zero_optim_wrapper(model, optim, optim_config=zero2_config) + verbose (bool, optional): Whether to print the verbose info. + """ + assert hasattr(model, "_colo_zero_stage"), "You should use `zero_ddp_wrapper` first" + zero_stage = getattr(model, "_colo_zero_stage") + + assert norm_type == 2.0, "Current ZeRO optimizers only support 'norm_type=2'" + + if optim_config is None: + config_dict = dict() + else: + config_dict = copy(optim_config) + + config_dict['initial_scale'] = initial_scale + config_dict['growth_factor'] = growth_factor + config_dict['backoff_factor'] = backoff_factor + config_dict['growth_interval'] = growth_interval + config_dict['hysteresis'] = hysteresis + config_dict['min_scale'] = min_scale + config_dict['max_scale'] = max_scale + + if zero_stage in [1, 2]: + from colossalai.zero.low_level import LowLevelZeroOptimizer + config_dict['partition_grad'] = zero_stage == 2 + config_dict['clip_grad_norm'] = max_norm + return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose) + else: + from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer + config_dict['clipping_norm'] = max_norm + return ZeroOptimizer(optimizer, model, **config_dict, verbose=verbose) diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..49ff9b344268935ed6cddc42f62ec1081aa2512a --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,30 @@ +FROM hpcaitech/cuda-conda:11.3 + +# metainformation +LABEL org.opencontainers.image.source = "https://github.com/hpcaitech/ColossalAI" +LABEL org.opencontainers.image.licenses = "Apache License 2.0" +LABEL org.opencontainers.image.base.name = "docker.io/library/hpcaitech/cuda-conda:11.3" + +# install torch +RUN conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch + +# install apex +RUN git clone https://github.com/NVIDIA/apex && \ + cd apex && \ + pip install packaging && \ + pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./ + +# install colossalai +RUN git clone https://github.com/hpcaitech/ColossalAI.git \ + && cd ./ColossalAI \ + && CUDA_EXT=1 pip install -v --no-cache-dir . + +# install titans +RUN pip install --no-cache-dir titans + +# install tensornvme +RUN conda install cmake && \ + git clone https://github.com/hpcaitech/TensorNVMe.git && \ + cd TensorNVMe && \ + pip install -r requirements.txt && \ + pip install -v --no-cache-dir . diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md new file mode 100644 index 0000000000000000000000000000000000000000..9d5bcfe3f9747dcfc85123451609915ea53f7cd9 --- /dev/null +++ b/docs/README-zh-Hans.md @@ -0,0 +1,429 @@ +# Colossal-AI +
+ + [![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/colossal-ai_logo_vertical.png)](https://www.colossalai.org/) + + Colossal-AI: 让AI大模型更低成本、方便易用、高效扩展 + +

论文 | + 文档 | + 例程 | + 论坛 | + 博客

+ + [![GitHub Repo stars](https://img.shields.io/github/stars/hpcaitech/ColossalAI?style=social)](https://github.com/hpcaitech/ColossalAI/stargazers) + [![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml) + [![Documentation](https://readthedocs.org/projects/colossalai/badge/?version=latest)](https://colossalai.readthedocs.io/en/latest/?badge=latest) + [![CodeFactor](https://www.codefactor.io/repository/github/hpcaitech/colossalai/badge)](https://www.codefactor.io/repository/github/hpcaitech/colossalai) + [![HuggingFace badge](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Join-yellow)](https://huggingface.co/hpcai-tech) + [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) + [![WeChat badge](https://img.shields.io/badge/微信-加入-green?logo=wechat&)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png) + + | [English](README.md) | [中文](README-zh-Hans.md) | + +
+ +## 新闻 +* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) +* [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana) +* [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs) +* [2023/02] [Open Source Solution Replicates ChatGPT Training Process! Ready to go with only 1.6GB GPU Memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt) +* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://medium.com/pytorch/latest-colossal-ai-boasts-novel-automatic-parallelism-and-offers-savings-up-to-46x-for-stable-1453b48f3f02) +* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) +* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding) +* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the) + + +## 目录 + + +## 为何选择 Colossal-AI +
+ + + + + James Demmel 教授 (加州大学伯克利分校): Colossal-AI 让分布式训练高效、易用、可扩展。 +
+ +

(返回顶端)

+ +## 特点 + +Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的分布式 AI 模型像构建普通的单 GPU 模型一样简单。我们提供的友好工具可以让您在几行代码内快速开始分布式训练和推理。 + +- 并行化策略 + - 数据并行 + - 流水线并行 + - 1维, [2维](https://arxiv.org/abs/2104.05343), [2.5维](https://arxiv.org/abs/2105.14500), [3维](https://arxiv.org/abs/2105.14450) 张量并行 + - [序列并行](https://arxiv.org/abs/2105.13120) + - [零冗余优化器 (ZeRO)](https://arxiv.org/abs/1910.02054) + - [自动并行](https://arxiv.org/abs/2302.02599) +- 异构内存管理 + - [PatrickStar](https://arxiv.org/abs/2108.05818) +- 使用友好 + - 基于参数文件的并行化 +- 推理 + - [Energon-AI](https://github.com/hpcaitech/EnergonAI) + +

(返回顶端)

+ +## Colossal-AI 成功案例 +### ColossalChat + + + +[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat): 完整RLHF流程0门槛克隆 [ChatGPT](https://openai.com/blog/chatgpt/) [[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) [[博客]](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) [[在线样例]](https://chat.colossalai.org) + +

+ +

+ +- 最高可提升单机训练速度7.73倍,单卡推理速度1.42倍 + +

+ +

+ +- 单卡模型容量最多提升10.3倍 +- 最小demo训练流程最低仅需1.62GB显存 (任意消费级GPU) + +

+ +

+ +- 提升单卡的微调模型容量3.7倍 +- 同时保持高速运行 + +

(back to top)

+ +### AIGC +加速AIGC(AI内容生成)模型,如[Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) 和 [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion) + +

+ +

+ +- [训练](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): 减少5.6倍显存消耗,硬件成本最高降低46倍(从A100到RTX3060) + +

+ +

+ +- [DreamBooth微调](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/dreambooth): 仅需3-5张目标主题图像个性化微调 + +

+ +

+ +- [推理](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): GPU推理显存消耗降低2.5倍 + + +

(返回顶端)

+ +### 生物医药 + +加速 [AlphaFold](https://alphafold.ebi.ac.uk/) 蛋白质结构预测 + +

+ +

+ +- [FastFold](https://github.com/hpcaitech/FastFold): 加速AlphaFold训练与推理、数据前处理、推理序列长度超过10000残基 + +

+ +

+ +- [FastFold with Intel](https://github.com/hpcaitech/FastFold): 3倍推理加速和39%成本节省 + +

+ +

+ +- [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): 11倍加速蛋白质单体与复合物结构预测 + +

(返回顶端)

+ +## 并行训练样例展示 + +### GPT-3 +

+ +

+ +- 释放 50% GPU 资源占用, 或 10.7% 加速 + +### GPT-2 + + +- 降低11倍 GPU 显存占用,或超线性扩展(张量并行) + + + +- 用相同的硬件训练24倍大的模型 +- 超3倍的吞吐量 + +### BERT + + +- 2倍训练速度,或1.5倍序列长度 + +### PaLM +- [PaLM-colossalai](https://github.com/hpcaitech/PaLM-colossalai): 可扩展的谷歌 Pathways Language Model ([PaLM](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html)) 实现。 + +### OPT + + +- [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), 由Meta发布的1750亿语言模型,由于完全公开了预训练参数权重,因此促进了下游任务和应用部署的发展。 +- 加速45%,仅用几行代码以低成本微调OPT。[[样例]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/opt) [[在线推理]](https://colossalai.org/docs/advanced_tutorials/opt_service) + +请访问我们的 [文档](https://www.colossalai.org/) 和 [例程](https://github.com/hpcaitech/ColossalAI/tree/main/examples) 以了解详情。 + +### ViT +

+ +

+ +- 14倍批大小和5倍训练速度(张量并行=64) + +### 推荐系统模型 +- [Cached Embedding](https://github.com/hpcaitech/CachedEmbedding), 使用软件Cache实现Embeddings,用更少GPU显存训练更大的模型。 + + +

(返回顶端)

+ +## 单GPU训练样例展示 + +### GPT-2 +

+ +

+ +- 用相同的硬件训练20倍大的模型 + +

+ +

+ +- 用相同的硬件训练120倍大的模型 (RTX 3080) + +### PaLM +

+ +

+ +- 用相同的硬件训练34倍大的模型 + +

(返回顶端)

+ + +## 推理 (Energon-AI) 样例展示 + +

+ +

+ +- [Energon-AI](https://github.com/hpcaitech/EnergonAI) :用相同的硬件推理加速50% + +

+ +

+ +- [OPT推理服务](https://colossalai.org/docs/advanced_tutorials/opt_service): 体验1750亿参数OPT在线推理服务 + +

+ +

+ +- [BLOOM](https://github.com/hpcaitech/EnergonAI/tree/main/examples/bloom): 降低1760亿参数BLOOM模型部署推理成本超10倍 + +

(返回顶端)

+ +## 安装 + +环境要求: + +- PyTorch >= 1.11 (PyTorch 2.x 正在适配中) +- Python >= 3.7 +- CUDA >= 11.0 +- [NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus) >= 7.0 (V100/RTX20 and higher) +- Linux OS + +如果你遇到安装问题,可以向本项目 [反馈](https://github.com/hpcaitech/ColossalAI/issues/new/choose)。 + + +### 从PyPI安装 + +您可以用下面的命令直接从PyPI上下载并安装Colossal-AI。我们默认不会安装PyTorch扩展包。 + +```bash +pip install colossalai +``` + +**注:目前只支持Linux。** + +但是,如果你想在安装时就直接构建PyTorch扩展,您可以设置环境变量`CUDA_EXT=1`. + +```bash +CUDA_EXT=1 pip install colossalai +``` + +**否则,PyTorch扩展只会在你实际需要使用他们时在运行时里被构建。** + +与此同时,我们也每周定时发布Nightly版本,这能让你提前体验到新的feature和bug fix。你可以通过以下命令安装Nightly版本。 + +```bash +pip install colossalai-nightly +``` + +### 从源码安装 + +> 此文档将与版本库的主分支保持一致。如果您遇到任何问题,欢迎给我们提 issue :) + +```shell +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI + +# install dependency +pip install -r requirements/requirements.txt + +# install colossalai +pip install . +``` + +我们默认在`pip install`时不安装PyTorch扩展,而是在运行时临时编译,如果你想要提前安装这些扩展的话(在使用融合优化器时会用到),可以使用一下命令。 + +```shell +CUDA_EXT=1 pip install . +``` + +

(返回顶端)

+ +## 使用 Docker + +### 从DockerHub获取镜像 + +您可以直接从我们的[DockerHub主页](https://hub.docker.com/r/hpcaitech/colossalai)获取最新的镜像,每一次发布我们都会自动上传最新的镜像。 + +### 本地构建镜像 + +运行以下命令从我们提供的 docker 文件中建立 docker 镜像。 + +> 在Dockerfile里编译Colossal-AI需要有GPU支持,您需要将Nvidia Docker Runtime设置为默认的Runtime。更多信息可以点击[这里](https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime)。 +> 我们推荐从[项目主页](https://www.colossalai.org)直接下载Colossal-AI. + +```bash +cd ColossalAI +docker build -t colossalai ./docker +``` + +运行以下命令从以交互式启动 docker 镜像. + +```bash +docker run -ti --gpus all --rm --ipc=host colossalai bash +``` + +

(返回顶端)

+ +## 社区 +欢迎通过[论坛](https://github.com/hpcaitech/ColossalAI/discussions), +[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +或[微信](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode")加入 Colossal-AI 社区,与我们分享你的建议和问题。 + + +## 做出贡献 + +参考社区的成功案例,如 [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion) 等, +无论是个人开发者,还是算力、数据、模型等可能合作方,都欢迎参与参与共建 Colossal-AI 社区,拥抱大模型时代! + +您可通过以下方式联系或参与: +1. [留下Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) 展现你的喜爱和支持,非常感谢! +2. 发布 [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), 或者在GitHub根据[贡献指南](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md) 提交一个 PR。 +3. 发送你的正式合作提案到 contact@hpcaitech.com + +真诚感谢所有贡献者! + + + + + +

(返回顶端)

+ + +## CI/CD + +我们使用[GitHub Actions](https://github.com/features/actions)来自动化大部分开发以及部署流程。如果想了解这些工作流是如何运行的,请查看这个[文档](https://github.com/hpcaitech/ColossalAI/blob/main/.github/workflows/README.md). + + +## 引用我们 + +Colossal-AI项目受一些相关的项目启发而成立,一些项目是我们的开发者的科研项目,另一些来自于其他组织的科研工作。我们希望. 我们希望在[参考文献列表](./REFERENCE.md)中列出这些令人称赞的项目,以向开源社区和研究项目致谢。 + +你可以通过以下格式引用这个项目。 + +``` +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +``` + +Colossal-AI 已被 [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/)等顶级会议录取为官方教程。 + +

(返回顶端)

diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f520608d552cf81e6752e66225dee5c884f38404 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,112 @@ +# 📕 Documentation + +## 🔗 Table of Contents + +- [📕 Documentation](#-documentation) + - [🔗 Table of Contents](#-table-of-contents) + - [📝 Overview](#-overview) + - [🗺 Module Structure](#-module-structure) + - [🧱 Our Documentation System](#-our-documentation-system) + - [🎊 Contribution](#-contribution) + - [🖊 Adding a New Documentation](#-adding-a-new-documentation) + - [🧹 Doc Testing](#-doc-testing) + - [💉 Auto Documentation](#-auto-documentation) + +## 📝 Overview + +We evaluated various existing solutions for documentation in the community and discussed their advantages and disadvantages in the [issue #2651](https://github.com/hpcaitech/ColossalAI/issues/2651). Therefore, we propose to build a more modern and robust documentation system by integrating the Sphinx [autodoc](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html) function and the [Docusaurus](https://docusaurus.io/) framework. + +## 🗺 Module Structure + +```text +- docs + - source + - en + - zh-Hans + - sidebars.json + - versions.json + - requirements-doc-test.txt +``` + +The documentation module structure is shown above: +1. source: This folder contains multi-language documentation files. +2. `sidebars.json`: The `sidebars.json` defines the table of content for the tutorials. You need to update this file when a new doc is added/deleted. +3. `versions.json`: The `versions.json` in the **main branch** in the **latest commit** will be used to control the versions to be displayed on our website + +## 🧱 Our Documentation System + +We believe that the combination of the existing systems can yield several advantages such as simplicity, usability and maintainability: +1. Support [Markdown](https://www.markdownguide.org/). We believe is a more popular language for writing documentation compared to [RST](https://docutils.sourceforge.io/rst.html). +2. Support Autodoc. It can automatically generate documentation from the docstrings in the source code provided by [Sphinx](https://www.sphinx-doc.org/en/master/). +3. Support elegant and modern UI, which is provided by [Docusaurus](https://docusaurus.io/). +4. Support MDX for more flexible and powerful documentation, which is provided by [Docusaurus](https://docusaurus.io/). +5. Support hosting blogs/project home page/other pages besides the documentation, which is provided by [Docusaurus](https://docusaurus.io/). + +Therefore, we have built the [ColossalAI-Documentation](https://github.com/hpcaitech/ColossalAI-Documentation) repository to integrate the features above. + +## 🎊 Contribution + +You can contribute to the documentation by directly setting up a Pull Request towards the `docs/source` folder. There are several guidelines for documentation contribution. + +1. The documentation is written in Markdown. You can refer to the [Markdown Guide](https://www.markdownguide.org/) for the syntax. +2. You must ensure that the documentation exists for all languages. You can refer to the [Adding a New Documentation](#-adding-a-new-documentation) for more details. +3. You must provide a test command for your documentation, please see [Doc Testing](#-doc-testing) for more details. +4. You can embed your docstring in your markdown, please see [Auto Documentation](#-auto-documentation) for more details. + +### 🖊 Adding a New Documentation + +You can add a Markdown file to the `docs/source` folder`. You need to ensure that multi-language is supported in your PR. +Let's assume that you want to add a file called `your_doc.md`, your file structure will look like this. + +```text +- docs + - source + - en + - your_doc.md # written in English + - zh-Hans + - your_doc.md # written in Chinese + - sidebars.json # add your documentation file name here +``` + +Meanwhile, you need to ensure the `sidebars.json` is updated such that it contains your documentation file. Our CI will check whether documentation exists for all languages and can be used to build the website successfully. + +### 🧹 Doc Testing + +Every documentation is tested to ensure it works well. You need to add the following line to the **bottom of your file** and replace `$command` with the actual command. Do note that the markdown will be converted into a Python file. Assuming you have a `demo.md` file, the test file generated will be `demo.py`. Therefore, you should use `demo.py` in your command, e.g. `python demo.py`. + +```markdown + +``` + +Meanwhile, only code labeled as a Python code block will be considered for testing. + +```markdown + ```python + print("hello world") + ``` +``` + +Lastly, if you want to skip some code, you just need to add the following annotations to tell `docer` to discard the wrapped code for testing. + +```markdown + + + ```python + print("hello world") + ``` + + +``` + +If you have any dependency required, please add it to `requriements-doc-test.txt` for pip and `conda-doc-test-deps.yml` for Conda. + + +### 💉 Auto Documentation + +Lastly, you may want to include the API documentation for a class/function in your documentation for reference. +We support `autodoc` to extract the docstring and transform it into a Web element for an elegant display. +You just need to add `{{ autodoc: }}` in your markdown as a single line. An example is given below and you can see the outcome in [this PR](https://github.com/hpcaitech/ColossalAI-Documentation/pull/175). + +```markdown +{{ autodoc:colossalai.amp.apex_amp.convert_to_apex_amp }} +``` diff --git a/docs/REFERENCE.md b/docs/REFERENCE.md new file mode 100644 index 0000000000000000000000000000000000000000..2681198191cba708d776663bca2cdcdb497eb931 --- /dev/null +++ b/docs/REFERENCE.md @@ -0,0 +1,38 @@ +# References + +The Colossal-AI project aims to provide a wide array of parallelism techniques for the machine learning community in the big-model era. This project is inspired by quite a few reserach works, some are conducted by some of our developers and the others are research projects open-sourced by other organizations. We would like to credit these amazing projects below in the IEEE citation format. + +## By Our Team + +- Q. Xu, S. Li, C. Gong, and Y. You, ‘An Efficient 2D Method for Training Super-Large Deep Learning Models’. arXiv, 2021. + +- Z. Bian, Q. Xu, B. Wang, and Y. You, ‘Maximizing Parallelism in Distributed Training for Huge Neural Networks’. arXiv, 2021. + +- S. Li, F. Xue, C. Baranwal, Y. Li, and Y. You, ‘Sequence Parallelism: Long Sequence Training from System Perspective’. arXiv, 2021. + +- S. Li et al., ‘Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training’. arXiv, 2021. + +- B. Wang, Q. Xu, Z. Bian, and Y. You, ‘Tesseract: Parallelize the Tensor Parallelism Efficiently’, in Proceedings of the 51th International Conference on Parallel Processing, 2022. + +- J. Fang et al., ‘A Frequency-aware Software Cache for Large Recommendation System Embeddings’. arXiv, 2022. + +- J. Fang et al., ‘Parallel Training of Pre-Trained Models via Chunk-Based Dynamic Memory Management’, IEEE Transactions on Parallel and Distributed Systems, vol. 34, no. 1, pp. 304–315, 2023. + +- Y. Liu, S. Li, J. Fang, Y. Shao, B. Yao, and Y. You, ‘Colossal-Auto: Unified Automation of Parallelization and Activation Checkpoint for Large-scale Models’. arXiv, 2023. + + +## By Other Organizations + +- M. Shoeybi, M. Patwary, R. Puri, P. LeGresley, J. Casper, and B. Catanzaro, ‘Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism’. arXiv, 2019. + +- S. Rajbhandari, J. Rasley, O. Ruwase, and Y. He, ‘ZeRO: Memory Optimizations toward Training Trillion Parameter Models’, in Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis, 2020. + +- J. Rasley, S. Rajbhandari, O. Ruwase, and Y. He, ‘DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters’, in Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, Virtual Event, CA, USA, 2020, pp. 3505–3506. + +- D. Narayanan et al., ‘Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM’, in Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis, St. Louis, Missouri, 2021. + +- Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. 2021. ZeRO-Offload: Democratizing Billion-Scale Model Training. arXiv:2101.06840 and USENIX ATC 2021. + +- S. Rajbhandari, O. Ruwase, J. Rasley, S. Smith, and Y. He, ‘ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning’. in Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis, St. Louis, Missouri, 2021. + +- L. Zheng et al., ‘Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning’, in 16th USENIX Symposium on Operating Systems Design and Implementation (OSDI 22), 2022, pp. 559–578. diff --git a/docs/conda-doc-test-deps.yml b/docs/conda-doc-test-deps.yml new file mode 100644 index 0000000000000000000000000000000000000000..74a232214adc57d7f24059a0b28cc61070caec0a --- /dev/null +++ b/docs/conda-doc-test-deps.yml @@ -0,0 +1,2 @@ +dependencies: + - cmake diff --git a/docs/requirements-doc-test.txt b/docs/requirements-doc-test.txt new file mode 100644 index 0000000000000000000000000000000000000000..79e04bd5615d7188a8a33cbd03056173cd3cb92c --- /dev/null +++ b/docs/requirements-doc-test.txt @@ -0,0 +1,7 @@ +colossalai +torch +packaging +tensornvme +psutil +transformers +pytest diff --git a/docs/sidebars.json b/docs/sidebars.json new file mode 100644 index 0000000000000000000000000000000000000000..44287c17eadf45d73e96b53a2dc98387141e7244 --- /dev/null +++ b/docs/sidebars.json @@ -0,0 +1,80 @@ +{ + "tutorialSidebar": [ + { + "type": "category", + "label": "Get started", + "collapsed": true, + "items": [ + "get_started/installation", + "get_started/run_demo", + "get_started/reading_roadmap" + ] + }, + { + "type": "category", + "label": "Concepts", + "collapsed": true, + "items": [ + "concepts/distributed_training", + "concepts/paradigms_of_parallelism", + "concepts/colossalai_overview" + ] + }, + { + "type": "category", + "label": "Basics", + "collapsed": true, + "items": [ + "basics/command_line_tool", + "basics/define_your_config", + "basics/launch_colossalai", + "basics/initialize_features", + "basics/engine_trainer", + "basics/configure_parallelization", + "basics/model_checkpoint", + "basics/colotensor_concept" + ] + }, + { + "type": "category", + "label": "Features", + "collapsed": true, + "items": [ + "features/mixed_precision_training", + "features/gradient_accumulation", + "features/gradient_clipping", + "features/gradient_handler", + "features/zero_with_chunk", + { + "type": "category", + "label": "Tensor Parallel", + "collapsed": true, + "items": [ + "features/1D_tensor_parallel", + "features/2D_tensor_parallel", + "features/2p5D_tensor_parallel", + "features/3D_tensor_parallel" + ] + }, + "features/pipeline_parallel", + "features/nvme_offload" + ] + }, + { + "type": "category", + "label": "Advanced Tutorials", + "collapsed": true, + "items": [ + "advanced_tutorials/train_vit_using_pipeline_parallelism", + "advanced_tutorials/train_vit_with_hybrid_parallelism", + "advanced_tutorials/train_gpt_using_hybrid_parallelism", + "advanced_tutorials/define_your_own_parallel_model", + "advanced_tutorials/add_your_parallel", + "advanced_tutorials/meet_gemini", + "advanced_tutorials/parallelize_your_training_like_Megatron", + "advanced_tutorials/integrate_mixture_of_experts_into_your_model", + "advanced_tutorials/opt_service" + ] + } + ] +} diff --git a/docs/source/en/Colossal-Auto/feature/auto_checkpoint.md b/docs/source/en/Colossal-Auto/feature/auto_checkpoint.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/source/en/Colossal-Auto/feature/device_mesh.md b/docs/source/en/Colossal-Auto/feature/device_mesh.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/source/en/Colossal-Auto/feature/layout_converting_management.md b/docs/source/en/Colossal-Auto/feature/layout_converting_management.md new file mode 100644 index 0000000000000000000000000000000000000000..2082a33d8a39a25d0078d9f8c687872a1da475ff --- /dev/null +++ b/docs/source/en/Colossal-Auto/feature/layout_converting_management.md @@ -0,0 +1,13 @@ +When a tensor is required to have different sharding specs in upstream and downstream operators, we need to perform layout conversion processing, which can also be called redistribution. There are currently two mainstream methods, enumeration conversion, and dimension-by-dimension conversion. enumeration conversion is to enumerate all possible situations, and then find the corresponding conversion scheme in the table when conversion is required. However, it has a big problem. That is, as the dimension of the device mesh increases, the scale of this problem is so inflated that it cannot be solved by enumerating tables. Dimension-by-dimension conversion is for a sharding spec of an N-D tensor, X0X1...Xn-1, sharding spec is converted from 0 to n-1 dimension by dimension, so that no matter how many dimensions the device mesh and tensor have, with only one-time Scanning, a feasible conversion operation sequence is generated, the problem is that the conversion efficiency will be very poor. + +Therefore, we propose a novel algorithm, using heuristic search, to solve the conversion problem of sharding spec, which can be described as: +1. Generate all one-step transform sharding specs from source spec +2. In the one-step transform sharding specs, according to the similarity function, select a sharding spec with the "least difference" as the subsequent source sharding spec, and record the sharding spec in the transform path. If a sharding spec of the one-step transforms is the same as the target sharding spec, the algorithm ends. +3. Repeat 1, 2 until the end of the algorithm + + +| Source/target sharding spec pairs |All gather | Shard | All to All | One step transform | Best sharding spec |Transform path| +| :-: | :-: | :-: | :-: | :-: | :-: |:-: | +| $S_{01}RR, RS_{01}R$ | $S_0RR$ | - | $S_0RS_1, S_0S_1R$ | $S_0RR, S_0RS_1, S_0S_1R$ | $S_0RR$ | $S_0RR$ +| $S_0RR, RS_{01}RR$ | $RRR$ | $S_0S_1R, S_0RS_1$ | $RS_0R, RRS_0$ | $RRR$, $S_0S_1R$, $S_0RS_1$, $RS_0R$, $RRS_0$ | $RS_0R$ | $S_0RR$ -> $RS_0R$ +| $RS_0R, RS_{01}RR$ | $RRR$ | $RS_{01}R, S_1S_0R, RS_0S_1$ | $S_0RR, RRS_0$ | $RRR$, $RS_{01}R$, $S_1S_0R$, $RS_0S_1$, $S_0RR$, $RRS_0$ | $RS_{01}R$ | $S_0RR$ -> $RS_0R$ -> $RS_{01}R$ diff --git a/docs/source/en/Colossal-Auto/feature/tracer.md b/docs/source/en/Colossal-Auto/feature/tracer.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/source/en/Colossal-Auto/get_started/installation.md b/docs/source/en/Colossal-Auto/get_started/installation.md new file mode 100644 index 0000000000000000000000000000000000000000..d2a532bfa7b09adea25e3c556066cce53e8e5654 --- /dev/null +++ b/docs/source/en/Colossal-Auto/get_started/installation.md @@ -0,0 +1,27 @@ +# Setup + +## Announcement + +Our auto-parallel feature is a alpha version. It is still under development. We will keep updating it and make it more stable. If you encounter any problem, please feel free to raise an issue. + +## Requirements + +We need some extra dependencies to support auto-parallel. Please install them before using auto-parallel. + +### Install PyTorch + +We only support PyTorch 1.12 now, other versions are not tested. We will support more versions in the future. + +```bash +#conda +conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch +#pip +pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 +``` + +### Install pulp and coin-or-cbc + +```bash +pip install pulp +conda install -c conda-forge coin-or-cbc +``` diff --git a/docs/source/en/Colossal-Auto/get_started/introduction.md b/docs/source/en/Colossal-Auto/get_started/introduction.md new file mode 100644 index 0000000000000000000000000000000000000000..a2606dd2bf9f0c8e98e1f6e11c3033f3d99bbf18 --- /dev/null +++ b/docs/source/en/Colossal-Auto/get_started/introduction.md @@ -0,0 +1,44 @@ +# Introduction + +In recent years, the deployment of large-scale machine learning models has become increasingly important. However, distributed training systems often require **manual parallelization plans**, which can be complex and require expert knowledge in system engineering and configuration. This can be a challenge for most AI developers without the necessary skills. The need for manual parallelization can make deploying large-scale machine learning models difficult and expensive. + +**Colossal-Auto** simplifies the process of deploying large-scale machine learning models for AI developers. Compared to other solutions that require manual configuration of complex parallel policies and model modification, Colossal-Auto only requires one line of code from the user, along with cluster information and model configurations, to enable distributed training. Technically, It seamlessly **integrates with popular AI model frameworks like Hugging Face and Timm.** + + + +## Overview + +
+ +
+ + +## Usage + +```python +# wrap the model using auto_engine +model = autoparallelize(model, meta_input_samples) +# normal training loop +... +``` + + +## Graph Tracing + +Colossal-Auto is **the first auto-parallelism system** that uses static graph analysis based on the PyTorch framework. Obtaining a static execution plan for PyTorch, a dynamic graph framework, has long been an area of research in the field of machine learning systems. Colossal-Auto uses ColoTracer, a forked version of the torch.FX Tracer, to guide the search for an optimal parallelization strategy. The meta-information of each tensor, such as tensor shape, dims, dtype, etc., is computed and recorded during the tracing process. This approach has the advantage of better generalization, as it is not tied to specific models or configurations. + + + +## Fine-grained Parallelism Search +We investigate and research a number of current automatic parallel systems( Tofu , Flexflow , Alpa ) and some auto activation checkpoint algorithms( Rotor , Sublinear ). Inspired from these advanced systems, we build Colossal-Auto which is an automatic parallel system upon PyTorch framework. Colossal-Auto searches for strategies in regard to each operand with the goal of achieving the fastest runtime while meeting memory budget constraints. It ultimately determines the actual training time strategy, including the tensor split strategy for each tensor, the type of communication operators to be inserted between different computing nodes, whether to replace operators, etc. The tensor, data, and hybrid parallelism such as column and row split used by NVIDIA in Megatron-LM and other parallelism systems are all subsets of strategies that can be searched by Colossal-AI. In addition to these parallelisms that can be manually specified, Colossal-AI can specify a unique parallelism method for each operation and, potentially finding a better parallelism strategy than what human experts could provide. + + + +## Distributed Tensor and Shape-Consistency System + +The Colossal-AI system uses a device-mesh, similar to PyTorch's latest DTensor release, to manage its cluster. Colossal-AI uses a sharding-spec to annotate the storage status of each tensor and facilitate their distribution across the cluster. The system also employs a shape-consistency manager to automatically transform tensors between different sharding-specs, allowing for seamless slicing and dicing of tensors, while the shape-consistency manager ensures that the output of upstream operands is consistently stored in the cluster, regardless of how the input of downstream operands is stored. This makes Colossal-AI highly versatile and easy to use without users worrying about the storage status of tensors when performing operations on them. + +Here are some key advantages of Colossal-AI compared to PyTorch DTensor: +Colossal-AI's device-mesh uses cluster performance metrics and profiling results to estimate the time consumption of different communication operators. This helps Colossal-AI optimize communication between nodes and improve overall system efficiency. +Colossal-AI's shape-consistency manager uses a greedy search algorithm to find relatively efficient ways to transform tensors between different sharding-specs, rather than simply transforming dimensions one by one. This can lead to more efficient and effective transformations. +The integration of all-to-all operations in Colossal-AI increases the scalability of the system by enabling more efficient communication between nodes. This is especially useful for large-scale machine learning tasks that require the transfer of large amounts of data between nodes. diff --git a/docs/source/en/Colossal-Auto/get_started/run_demo.md b/docs/source/en/Colossal-Auto/get_started/run_demo.md new file mode 100644 index 0000000000000000000000000000000000000000..34872e399c815cec9bc0cca20c5c36de11389634 --- /dev/null +++ b/docs/source/en/Colossal-Auto/get_started/run_demo.md @@ -0,0 +1,13 @@ +# Quick Demo + +Colossal-Auto simplifies the process of deploying large-scale machine learning models for AI developers. Compared to other solutions that require manual configuration of complex parallel policies and model modification, Colossal-Auto only requires one line of code from the user, along with cluster information and model configurations, to enable distributed training. Quick demos showing how to use Colossal-Auto are given below. + +### 1. Basic usage + +Colossal-Auto can be used to find a hybrid SPMD parallel strategy includes data, tensor(i.e., 1D, 2D, sequential) for each operation. You can follow the [GPT example](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/experiments/auto_parallel). +Detailed instructions can be found in its `README.md`. + +### 2. Integration with activation checkpoint + +Colossal-Auto's automatic search function for activation checkpointing finds the most efficient checkpoint within a given memory budget, rather than just aiming for maximum memory compression. To avoid a lengthy search process for an optimal activation checkpoint, Colossal-Auto has implemented a two-stage search process. This allows the system to find a feasible distributed training solution in a reasonable amount of time while still benefiting from activation checkpointing for memory management. The integration of activation checkpointing in Colossal-AI improves the efficiency and effectiveness of large model training. You can follow the [Resnet example](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/auto_parallel). +Detailed instructions can be found in its `README.md`. diff --git a/docs/source/en/advanced_tutorials/add_your_parallel.md b/docs/source/en/advanced_tutorials/add_your_parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..be7284a7ab64824cb49dfb74a42831e14b4afb59 --- /dev/null +++ b/docs/source/en/advanced_tutorials/add_your_parallel.md @@ -0,0 +1,124 @@ +# Add Your Own Parallel Mode + +Author: Shenggui Li, Yongbin Li + +**Prerequisite:** +- [Define Your Configuration](../basics/define_your_config.md) +- [Configure Parallelization](../basics/configure_parallelization.md) + +## Introduction + +To enable researchers and engineers to extend our system to other novel large-scale distributed training algorithm +with less effort, we have decoupled various components in the training lifecycle. You can implement your own +parallelism by simply inheriting from the base class. + +The main components are: + +1. `ProcessGroupInitializer` +2. `GradientHandler` +3. `Schedule` + +**This currently requires some code to the source code, thus we recommend that you install from source with the `-e` flag. +`-e` flag makes the installation editable, thus, your code change will be reflected in your Python runtime. +We will work on this to avoid change to source code in future releases.** + + +## Process Group Initializer + +Parallelism is often managed by process groups where processes involved in the same parallel algorithm are placed in the same +process group. For different parallel algorithms, different process groups need to be created. Colossal-AI provides a +global context for users to easily manage their process groups. If you wish to add new process group, you can easily +define a new class and set it in your configuration file. To define your own way of creating process groups, you can +follow the steps below to create a new distributed initialization. + +1. Add your parallel mode in `colossalai.context.parallel_mode.ParallelMode`. + ```python + class ParallelMode(Enum): + GLOBAL = 'global' + DATA = 'data' + PIPELINE = 'pipe' + ... + + NEW_MODE = 'new_mode' # define your mode here + ``` + +2. Create a `ProcessGroupInitializer`. You can refer to examples given in `colossalai.context.dist_group_initializer`. The + first six arguments are fixed. `ParallelContext` will pass in these arguments for you. If you need to set other + arguments, you can add it behind like the `arg1, arg2` in the example below. Lastly, register your initializer to the + registry by adding the decorator `@DIST_GROUP_INITIALIZER.register_module`. + ```python + # sample initializer class + @DIST_GROUP_INITIALIZER.register_module + class MyParallelInitializer(ProcessGroupInitializer): + + def __init__(self, + rank: int, + world_size: int, + config: Config, + data_parallel_size: int, + pipeline_parlalel_size: int, + tensor_parallel_size: int, + arg1, + arg2): + super().__init__(rank, world_size, config) + self.arg1 = arg1 + self.arg2 = arg2 + # ... your variable init + + def init_parallel_groups(self): + # initialize your process groups + pass + + ``` + + Then, you can insert your new initializer to the current mode-to-initialize mapping + in `colossalai.constants.INITIALIZER_MAPPING`. You can modify the file or insert new key-value pair dynamically. + + ```python + colossalai.constants.INITIALIZER_MAPPING['new_mode'] = 'MyParallelInitializer' + ``` + +3. Set your initializer in your config file. You can pass in your own arguments if there is any. This allows + the `ParallelContext` to create your initializer and initialize your desired process groups. + + ```python + parallel = dict( + pipeline=dict(size=1), + tensor=dict(size=x, mode='new_mode') # this is where you enable your new parallel mode + ) + ``` + +## Gradient Handler + +Gradient handlers are objects which execute the all-reduce operations on parameters' gradients. As different all-reduce +strategies may be executed for different kinds of parallelism, users can +inherit `colossalai.engine.gradient_handler.BaseGradientHandler` to implement their strategies. Currently, the library +uses the normal data parallel gradient handler which all-reduces the gradients across data parallel ranks. The data +parallel gradient handler is added to the engine automatically if data parallel is detected. You can add your own +gradient handler like below: + +```python +from colossalai.registry import GRADIENT_HANDLER +from colossalai.engine import BaseGradientHandler + +@GRADIENT_HANDLER.register_module +class YourGradientHandler(BaseGradientHandler): + + def handle_gradient(self): + do_something() + +``` + +Afterwards, you can specify the gradient handler you want to use in your configuration file. + +```python +gradient_handlers = [ + dict(type='YourGradientHandler'), +] +``` + +## Schedule + +Schedule entails how to execute a forward and backward pass. Currently, Colossal-AI provides pipeline and non-pipeline +schedules. If you want to modify how the forward and backward passes are executed, you can +inherit `colossalai.engine.schedule.BaseSchedule` and implement the `forward_back_step` function. diff --git a/docs/source/en/advanced_tutorials/define_your_own_parallel_model.md b/docs/source/en/advanced_tutorials/define_your_own_parallel_model.md new file mode 100644 index 0000000000000000000000000000000000000000..8e48737d2f6435dd55f1f673647b82d6f64abaca --- /dev/null +++ b/docs/source/en/advanced_tutorials/define_your_own_parallel_model.md @@ -0,0 +1,36 @@ +# Define your own parallel model + +Author: Zhengda Bian, Yongbin Li + +> ⚠️ We are working on this documentation to make it more detailed. We will introduce the mechanism of different parallelism +> and how to use them to write a model. + +Let's say that you have a huge MLP model with billions of parameters and its extremely large hidden layer size makes it +impossible to fit into a single GPU directly. Don't worry, Colossal-AI is here to help you sort things out. With the help of Colossal-AI, +you can write your model in the familiar way in which you used to write models for a single GPU, while Colossal-AI automatically +splits your model weights and fit them perfectly into a set of GPUs. We give a simple example showing how to write a simple +2D parallel model in the Colossal-AI context. + +## Write a simple 2D parallel model + +```python +from colossalai.nn import Linear2D +import torch.nn as nn + +class MLP_2D(nn.Module): + + def __init__(self): + super().__init__() + self.linear_1 = Linear2D(in_features=1024, out_features=16384) + self.linear_2 = Linear2D(in_features=16384, out_features=1024) + + def forward(self, x): + x = self.linear_1(x) + x = self.linear_2(x) + return x +``` + +## Use pre-defined model + +For the sake of your convenience, we kindly provide you in our Model Zoo with some prevalent models such as *BERT*, *ViT*, *MoE*, +and *GPT*. Feel free to customize them into different sizes to fit into your special needs. diff --git a/docs/source/en/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md b/docs/source/en/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md new file mode 100644 index 0000000000000000000000000000000000000000..e01caf76d2b323959a4ce7d7d85521ff919a5385 --- /dev/null +++ b/docs/source/en/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md @@ -0,0 +1,139 @@ +# Integrate Mixture-of-Experts Into Your Model + +Author: Haichen Huang + +**Example Code** +- [ColossalAI-Examples WideNet](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet) + +**Related Paper** +- [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) +- [Go Wider Instead of Deeper](https://arxiv.org/abs/2107.11817) + + +## Introduction + +Since the advent of Switch Transformer, the AI community has found Mixture of Experts (MoE) a useful technique to enlarge the capacity of deep learning models. + +Colossal-AI provides an early access version of parallelism specifically designed for MoE models. +The most prominent advantage of MoE in Colossal-AI is convenience. +We aim to help our users to easily combine MoE with model parallelism and data parallelism. + +However, the current implementation has two main drawbacks now. +The first drawback is its poor efficiency in large batch size and long sequence length training. +The second drawback is incompatibility with tensor parallelism. +We are working on system optimization to overcome the training efficiency problem. +The compatibility problem with tensor parallelism requires more adaptation, and we will tackle this issue in the future. + +Here, we will introduce how to use MoE with model parallelism and data parallelism. + +## Table of Content +In this tutorial we will cover: +1. Set up MoE running environment +2. Create MoE layer +3. Train your model + +We provided the [example code](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet) for this tutorial in [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples). +This example uses [WideNet](https://arxiv.org/abs/2107.11817) as an example of MoE-based model. + + +## Set up MoE running environment +In your project folder, create a `config.py`. + +This file is to specify some features you may want to use to train your model. +In order to enable MoE, you need to add a dict called parallel and specify the value of key moe. +You can assign a value for the key size of moe, which represents the model parallel size of experts (i.e. the number of experts in one group to parallelize training). + +For example, if the size is 4, 4 processes will be assigned to 4 consecutive GPUs and these 4 processes form a moe model parallel group. +Each process on the 4 GPUs will only get a portion of experts. Increasing the model parallel size will reduce communication cost, but increase computation cost in each GPU and activation cost in memory. +The total data parallel size is auto-detected and set as the number of GPUs by default. + +```python +MOE_MODEL_PARALLEL_SIZE = ... +parallel = dict( + moe=dict(size=MOE_MODEL_PARALLEL_SIZE) +) +``` + +If `MOE_MODEL_PARALLEL_SIZE = E` and set the number of experts as `E` where `E` is a constant number, the process flow of forward pass of a transformer encoder in a model parallel group is shown below. + +
+ +
MoE Transformer, image source: GShard
+
+ +Since all experts are allocated to all GPUs in a model parallel group and a GPU only owns a portion of experts, +original data parallel groups are no longer correct for the parameters of experts during gradient handling in backward pass anymore. +So we create a new kind of parallel group called moe data parallel group. +The difference among different kinds of parallel group, when the configuration is set as `WORLD_SIZE=4`, +`MOE_MODEL_PARALLEL_SIZE=2`, is shown here. + +
+ +
MoE process group
+
+ + +As for gradient handling, we provide MoeGradientHandler to all-reduce every parameter of the model. +If you use `colossalai.initialize` function to create your training engine, the MoE gradient handler will be added to your engine automatically. +Otherwise, you should take care of gradient by yourself. +All parameters of MoE running environment are stored in colossalai.global_variables.moe_env. +You can access your configuration parameters to check whether your setup is correct. +```python +from colossalai.global_variables import moe_env +``` + +## Create MoE layer +You can create a MoE layer from `colossalai.nn.moe`. +But before doing that, you should set up random seeds for all processes like this. + +```python +from colossalai.context.random import moe_set_seed +from model_zoo.moe.models import Widenet + +moe_set_seed(42) +model = Widenet(num_experts=4, capacity_factor=1.2) +``` + +`moe_set_seed` will set different seed for different processes in a moe model parallel group. +This helps initialize parameters in experts. +Then create an instance of experts and an instance of router. +Here is the example in model zoo. + +```python +from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator + + +noisy_func = NormalNoiseGenerator(num_experts) +shared_router = Top2Router(capacity_factor, + noisy_func=noisy_func) +shared_experts = Experts(expert=VanillaFFN, + num_experts=num_experts, + **moe_mlp_args( + d_model=d_model, + d_ff=d_ff, + drop_rate=drop_rate + )) +ffn=MoeLayer(dim_model=d_model, num_experts=num_experts, + router=shared_router, experts=shared_experts) +``` + +Inside the initialization of Experts, the local expert number of each GPU will be calculated automatically. You just need to specify the class of each expert and its parameters used in its initialization. As for routers, we have provided top1 router and top2 router. You can find them in colossalai.nn.layer.moe. After creating the instance of experts and router, the only thing initialized in Moelayer is gate module. More definitions of each class can be found in our API document and code. + + +## Train Your Model +Do not to forget to use `colossalai.initialize` function in `colosalai` to add gradient handler for the engine. +We handle the back-propagation of MoE models for you. +In `colossalai.initialize`, we will automatically create a `MoeGradientHandler` object to process gradients. +You can find more information about the handler `MoeGradientHandler` in colossal directory. + +The loss criterion should be wrapped by `Moeloss` to add auxiliary loss of MoE. Example is like this. +```python +criterion = MoeLoss( + aux_weight=0.01, + loss_fn=nn.CrossEntropyLoss, + label_smoothing=0.1 +) +``` + +Finally, just use trainer or engine in `colossalai` to do your training. +Otherwise, you should take care of gradient by yourself. diff --git a/docs/source/en/advanced_tutorials/meet_gemini.md b/docs/source/en/advanced_tutorials/meet_gemini.md new file mode 100644 index 0000000000000000000000000000000000000000..8afb6705b6ae84afc23cc16767cf2581719c1abb --- /dev/null +++ b/docs/source/en/advanced_tutorials/meet_gemini.md @@ -0,0 +1,88 @@ + +# Meet Gemini:The Heterogeneous Memory Manager of Colossal-AI + +Author: [Jiarui Fang](https://github.com/feifeibear), Yang You + +## Brief + +When you only have a few GPUs for large model training tasks, **heterogeneous training** is the most effective approach. By accommodating model data in CPU and GPU and moving the data to the computing device when necessary, it can breakthrough the GPU memory wall by using GPU and CPU memory (composed of CPU DRAM or nvme SSD memory) together at the same time. Moreover, the model scale can be further improved by combining heterogeneous training with the other parallel approaches, such as data parallel, tensor parallel and pipeline parallel . We now describe the design details of **Gemini**, the heterogeneous memory space manager of Colossal-AI. Its idea comes from [PatrickStar](https://arxiv.org/abs/2108.05818), which has been adapted to Colossal-AI. + +## Usage + +At present, Gemini supports compatibility with ZeRO parallel mode, and it is really simple to use Gemini. Set attribute of zero model_config, i.e., tensor_placement_policy='auto'. + +``` +zero = dict( + model_config=dict( + tensor_placement_policy='auto', + shard_strategy=BucketTensorShardStrategy() + ), + optimizer_config=dict( + ...) +) +``` + +Note that Gemini and parallel strategies such as tensor parallelism, data parallelism, pipeline parallelism and zero should be decoupled. However, Colossal-AI requires users to use Gemini with ZeRO. Although they are not necessarily coupled, we will improve it in the near future. + +## Concepts + +**OP**(**OP**erator):operation of a neural network layer, such as linear, LayerNorm, etc. The operator can be a forward propagation calculation or a back-propagation calculation. + +Neural networks must manage two types of training data during training. +**model data**: consists of parameters, gradients and optimizer states, and its scale is related to the definition of model structure. + +**Non-model data**: mainly composed of the intermediate tensor generated by the operator and the temporary variables of the operator. Non-model data changes dynamically according to the configuration of training tasks, such as batch size. Model data and non-model data compete with each other for GPU memory. + +## Design Details + + +In some solutions, the [Zero-offload](https://arxiv.org/abs/2101.06840) adopted by DeepSpeed statically divides model data between CPU and GPU memory, and their memory layout is constant for different training configurations. As shown on the left of the figure below, when the GPU memory is insufficient to meet its corresponding model data requirements, the system will crash even if there is still available memory on the CPU at that time. While Colossal-AI can complete the training by moving part of the model data to the CPU. + +
+ +
Comparison of the memory management of Zero-Offload and Gemini
+
+ + +Colossal-AI designed Gemini, just like two-stars, which manages the memory space of CPU and GPU efficiently. It can make the tensor dynamically distributed in the storage space of CPU-GPU during training, so that the model training can break through the memory wall of GPU. The memory manager consists of two parts: **MemStatsCollector (MSC)** and **StatefulTensorMgr (STM)**. + +We take advantage of the iterative characteristics of the deep learning network training process. We divide iterations into two stages: warmup and non-warmup. One or several iterative steps at the beginning belong to the warmup stage, and the other iterative steps belong to the non-warmup stage. In the warmup stage, we collect information for the MSC, while in the non-warmup stage, STM gets the information collected by the MSC to move the tensor, so as to minimize the CPU-GPU data movement volume. + +
+ +
The workflow of Gemini during warmup and non-warmup phase
+
+ + +### StatefulTensorMgr + +STM manages the information of all model data tensors. In the process of model construction, Colossal-AI registers all model data tensors with STM. The memory manager marks each tensor with state information. The state set includes three types: HOLD, COMPUTE and FREE. The functions of STM are as follows: + +**Query memory usage:**by traversing the locations of all tensors in heterogeneous space, obtain the memory occupation of CPU and GPU by model data. + +**Transition tensor state:** it marks the tensor as COMPUTE state before each model data tensor participates in the operator calculation, and as HOLD state after calculation. The FREE state marked if the tensor is no longer in use. + +**Adjust tensor position:**tensor manager ensures that the tensor in COMPUTE state is placed on the computing device. If the storage space of the computing device is insufficient, it is necessary to move some tensors in HOLD state to other devices for storage. Tensor eviction strategy requires information from MSC, which will be introduced later. + + +### MemStatsCollector +In the warmup stage, the memory information statistician monitors the memory usage of model data and non-model data in CPU and GPU for reference in the non-warmup stage. We can obtain the memory usage of model data at a certain time by querying STM. However, the memory usage of non-model data is difficult to obtain. Owing to the life cycle of non-model data not being managed by users, the existing deep learning framework does not expose the tracking interface of non-model data to users. MSC obtains the usage of CPU and GPU memory by non-model in the warmup stage through sampling. The specific methods are as follows: + +We trigger the memory sampling operation at the beginning and end of the operator. We call this time point **sampling moment**, and the time between the two sampling moments is called **period**. The calculation process is a black box. Due to the possible allocation of temporary buffer, the memory usage is very complex. However, we can accurately obtain the maximum memory usage of the system during the period. The use of non-model data can be obtained by the maximum memory use of the system between two statistical moments-model memory use. + +How do we design the sampling time. Before we choose model data layout adjust of preOp. As shown in the figure below. We sample the system memory used of the previous period and the model data memory used of the next period. The parallel strategy will cause obstacles to the work of MSC. As shown in the figure, for example, for ZeRO or Tensor Parallel, because gathering model data is required before OP calculation, it will bring additional memory requirements. Therefore, we require to sample the system memory before the model data changes, so that the MSC will capture the model change memory of preOp within a period. For example, in period 2-3, we consider the memory changes brought by tensor gather and shard. + +Although the sampling time can be placed in other locations, such as excluding the new information of the change of the gather buffer, it will cause trouble. There are differences in the implementation of Op in different parallel modes. For example, for Linear Op, gather buffer in Tensor Parallel is allocated in Op. For ZeRO, the allocation of gather buffer is in PreOp. Sampling at the beginning of PreOp helps to unify the two situations. + +
+ +
workflow
+
+ +### Tensor Eviction Strategy + +The important duty of MSC is to adjust the tensor layout position. For example, at S2 in the figure above, we reduce the model data on the device, and meet the peak memory requirement calculated in period 2-3. + +In the warmup stage, since we haven't finished a complete iteration yet, we don't know actual memory occupation. At this time, we limit the upper bound of memory usage of the model data. For example, only 30% of the GPU memory can be used. This ensures that we can successfully complete the warmup state. + +In the non-warmup stage, we need to use the memory information of non-model data collected in the warm-up stage to reserve the peak memory required by the computing device for the next Period, which requires us to move some model tensors. In order to avoid frequent replacement of the same tensor in and out of the CPU-GPU, causing a phenomenon similar to [cache thrashing](https://en.wikipedia.org/wiki/Thrashing_(computer_science)). Using the iterative characteristics of DNN training, we design the OPT cache swap out strategy. Specifically, in the warmup stage, we record the sampling time required by each tensor computing device. If we need to expel some HOLD tensors, we will choose the latest tensor needed on this device as the victim. diff --git a/docs/source/en/advanced_tutorials/opt_service.md b/docs/source/en/advanced_tutorials/opt_service.md new file mode 100644 index 0000000000000000000000000000000000000000..a43ec7fdd1fe8736a90e05ba186fe554c1e75384 --- /dev/null +++ b/docs/source/en/advanced_tutorials/opt_service.md @@ -0,0 +1,81 @@ +# Build an online OPT service using Colossal-AI in 5 minutes + +## Introduction + +This tutorial shows how to build your own service with OPT with the help of [Colossal-AI](https://github.com/hpcaitech/ColossalAI). + +## Colossal-AI Inference Overview +Colossal-AI provides an inference subsystem [Energon-AI](https://github.com/hpcaitech/EnergonAI), a serving system built upon Colossal-AI, which has the following characteristics: + +- **Parallelism for Large-scale Models:** With the help of tensor parallel operations, pipeline parallel strategies from Colossal-AI, Colossal-AI inference enables efficient parallel inference for large-scale models. +- **Pre-built large models:** There are pre-built implementations for popular models, such as OPT. It supports a caching technique for the generation task and checkpoints loading. +- **Engine encapsulation:** There has an abstraction layer called an engine. It encapsulates the single instance multiple devices (SIMD) execution with the remote procedure call, making it act as the single instance single device (SISD) execution. +- **An online service system:** Based on FastAPI, users can launch a web service of a distributed inference quickly. The online service makes special optimizations for the generation task. It adopts both left padding and bucket batching techniques to improve efficiency. + +## Basic Usage: + +1. Download OPT model + +To launch the distributed inference service quickly, you can download the OPT-125M from [here](https://huggingface.co/patrickvonplaten/opt_metaseq_125m/blob/main/model/restored.pt). You can get details for loading other sizes of models [here](https://github.com/hpcaitech/EnergonAI/tree/main/examples/opt/script). + +2. Prepare a prebuilt service image + +Pull a docker image from docker hub installed with Colossal-AI inference. + +```bash +docker pull hpcaitech/energon-ai:latest +``` + +3. Launch an HTTP service + +To launch a service, we need to provide python scripts to describe the model type and related configurations, and settings for the HTTP service. +We have provided a set of [examples](https://github.com/hpcaitech/EnergonAI/tree/main/examples]). We will use the [OPT example](https://github.com/hpcaitech/EnergonAI/tree/main/examples/opt) in this tutorial. +The entrance of the service is a bash script server.sh. +The config of the service is at opt_config.py, which defines the model type, the checkpoint file path, the parallel strategy, and http settings. You can adapt it for your own case. +For example, set the model class as opt_125M and set the correct checkpoint path as follows. + +```bash +model_class = opt_125M +checkpoint = 'your_file_path' +``` + +Set the tensor parallelism degree the same as your gpu number. + +```bash +tp_init_size = #gpu +``` + +Now, we can launch a service using docker. You can map the path of the checkpoint and directory containing configs to local disk path `/model_checkpoint` and `/config`. + + +```bash +export CHECKPOINT_DIR="your_opt_checkpoint_path" +# the ${CONFIG_DIR} must contain a server.sh file as the entry of service +export CONFIG_DIR="config_file_path" + +docker run --gpus all --rm -it -p 8020:8020 -v ${CHECKPOINT_DIR}:/model_checkpoint -v ${CONFIG_DIR}:/config --ipc=host energonai:lastest +``` + +Then open `https://[IP-ADDRESS]:8020/docs#` in your browser to try out! + + +## Advance Features Usage: + +1. Batching Optimization + +To use our advanced batching technique to collect multiple queries in batches to serve, you can set the executor_max_batch_size as the max batch size. Note, that only the decoder task with the same top_k, top_p and temperature can be batched together. + +``` +executor_max_batch_size = 16 +``` + +All queries are submitted to a FIFO queue. All consecutive queries whose number of decoding steps is less than or equal to that of the head of the queue can be batched together. Left padding is applied to ensure correctness. executor_max_batch_size should not be too large. This ensures batching won't increase latency. For opt-30b, `executor_max_batch_size=16` may be a good choice, while for opt-175b, `executor_max_batch_size=4` may be better. + +2. Cache Optimization. + +You can cache several recently served query results for each independent serving process. Set the cache_size and cache_list_size in config.py. The cache size is the number of queries cached. The cache_list_size is the number of results stored for each query. And a random cached result will be returned. When the cache is full, LRU is applied to evict cached queries. cache_size=0means no cache is applied. + +``` +cache_size = 50 +cache_list_size = 2 +``` diff --git a/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md b/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md new file mode 100644 index 0000000000000000000000000000000000000000..e7698e5e9d1b9d5c99d52cd455281d7e9f358072 --- /dev/null +++ b/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md @@ -0,0 +1,192 @@ +# Parallelize Your Training like Megatron-LM via ColoTensor + +Author: [Haichen Huang](https://github.com/1SAA) and [Jiarui Fang](https://github.com/feifeibear) + +**Prerequisite:** +- [ColoTensor Concepts](../basics/colotensor_concept.md) + +## Introduction + +Thanks to the convenience given by ColoTensor, users can apply parallelism with the least edition to their serial code. +In this tutorial, we will illustrate how to modify the training model to automatically adapt the code to parallel training like Megatron-LM. +We take the GPT-2 model offered by HuggingFace as an example and provide a way for you to pre-train the GPT-2 model on a single GPU. + +Megatron-LM provided a profound paradigm to parallelize large transformer language models. +However, in order to train large transformer language models at scale, users have to build their models with those modules provided by Megatron. +It imposes several difficult jobs on users, such as loading the weights from the pre-trained models and constructing the parallelized models. +To mitigate users' trouble, we offer ColoTensor to enable the tensor model parallelism automatically. + +## Definitions of the model and the loss function + +First we use the GPTModel and GPTLoss directly from the HuggingFace library. + +```python +import torch +import torch.nn as nn +from transformers import GPT2Config, GPT2LMHeadModel + +class GPTLMModel(nn.Module): + def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers, + n_head=num_attention_heads, n_positions=max_seq_len, n_ctx=max_seq_len, vocab_size=vocab_size)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +class GPTLMLoss(nn.Module): + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) +``` + +## Brief Review of GPT-2 + +Now, we recall the structure of each GPT-2 model. +Every GPT-2 model can be represented as a DAG. +As shown in the below pictures, each circle represents an operator and each square represents a weight. +An arrow indicates the flow of the input data, and the notation alongside the arrow demonstrates the shape of the input data. + +Then, let's take an insight into this GPT-2 model. It consists of three parts. +They are the **embedding module**, **transformer layers**, and the **classification head**. + +The embedding module contains two weights, token embedding weight and position embedding weight. +After the forward operation of the embedding module, each word in all sequences of the raw input data will be embedded into a hidden state. + +
+ +
The embedding module
+
+ +Each transformer layer contains two blocks. The self-attention operation is called in the first block and a two-layer percepton is located in the second block. + +
+ +
The transformer layer
+
+ +In the end, the classification head is just a linear module without bias, which only has a weight inside. + +## Applied with ColoTensor + +Two steps make your serial code adapted to Megatron-LM tensor parallel style. +1. Initialize the model in the context of ColoInitContext. +2. Setting ColoTensorSpec for each parameter. + +### Initialize with ColoInitContext + +We should build the model in the ColoInitContext. +In this context, any parameter initialized would be transformed to ColoParameter and moved to the corresponded device automatically. + +```python +from colossalai.utils.model.colo_init_context import ColoInitContext + +with ColoInitContext(device=torch.device('cpu')): + model = GPTLMModel() +``` + +### Setting ColoTensorSpec for each parameter + +After the creation of the model, we establish the distributed environment through ProcessGroup. +Here, we specify the degree of the tensor parallelism as the same as the number of all GPUs, which means the degree of data parallelism is 1. + +```python +import torch.distributed as dist +from colossalai.tensor import ProcessGroup + +pg = ProcessGroup(tp_degree=dist.get_world_size()) +``` + +Now, some auxiliary functions are necessary for the next step. We define two functions to split a parameter. +Megatron-LM-like tensor parallelism requires splitting a parameter tensor along its first dimension or its last dimension. + +```python +from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern, ColoParameter, ProcessGroup + +def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + if param.process_group.tp_world_size() == 1: + param.set_process_group(pg) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(-1, param, pg) +``` + +Then we adapt the model to the tensor parallelism. +According to the tensor parallelism applied in Megatron, it is supposed to shard along the last dimension of tensors, including the weights of token embedding, position embedding, all linear weights and biases in self-attention blocks, the first weight linear and bias in each MLP. +And it shards the second linear weight along its first dimension. + +```python +for mn, module in model.named_modules(): + for pn, param in module.named_parameters(recurse=False): + # set process group for all parameters + param.set_process_group(pg) + + if 'mlp.c_fc' in mn: + if 'weight' in pn or 'bias' in pn: + split_param_col_tp1d(param, pg) # colmn slice + # keep the shape of the output from c_fc + param.compute_spec.set_output_replicate(False) + elif 'mlp.c_proj' in mn: + if 'weight' in pn: + split_param_row_tp1d(param, pg) # row slice + elif 'wte' in mn or 'wpe' in mn: + split_param_col_tp1d(param, pg) # colmn slice + elif 'c_attn' in mn or 'c_proj' in mn: + split_param_col_tp1d(param, pg) # colmn slice +``` + +The modified model is illustrated below. + +The embedding module: + +
+ +
The modified embedding module
+
+ +The transformer layers: + +
+ +
The modified transformer layer
+
+ +Once users have specified the distributed pattern of each parameter, ColoTensor is capable of inferring the computation patterns of all operators, including matrix multiplication, the linear function, other elementwise functions in torch.nn.functional, etc. +In this way, users can train their models as usual. + +In our latest example, a Gemini + ZeRO DDP model is also defined to reduce overhead and improve efficiency.For the details of this part, please refer to [ZeRO](../features/zero_with_chunk.md). You can combine these two parts to understand our entire training process: + +```python +def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): + from colossalai.nn.parallel import GeminiDDP + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=placememt_policy, + pin_memory=True, + search_range_mb=32) + return model +``` + +## Pretrain GPT-2 On Single GPU + +The above optimization we made allows us to pretrain the GPT-2 model on a single GPU. We only need to set the parameter `GPUNUM`=1 in `run.sh`, and then we can complete the model training on a single GPU when running the file. + +The GPT-2 example is accessible at [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md new file mode 100644 index 0000000000000000000000000000000000000000..715c15eb63003b7c1cde0b28d2533518755db34f --- /dev/null +++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -0,0 +1,270 @@ +# Train GPT Using Hybrid Parallelism + +Author: Hongxin Liu, Yongbin Li + +**Example Code** +- [ColossalAI-Examples GPT2](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt_2) +- [ColossalAI-Examples GPT3](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt_3) + +**Related Paper** +- [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883) +- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) + +## Introduction + +In the previous tutorial, we introduce how to train ViT with pipeline. In this tutorial, you will learn a more complex scenario -- train GPT with hybrid parallelism. In this case, GPT-3 is so large that CPU memory cannot fit it as well. Therefore, you must split the model by yourself. + +## Table of content + +In this tutorial we will cover: + +1. The definition of GPT model, based on colossalai/model_zoo +2. Processing the dataset +3. Training GPT using hybrid parallelism + +## Import libraries + +```python +import json +import os +from typing import Callable + +import colossalai +import colossalai.utils as utils +import model_zoo.gpt.gpt as col_gpt +import torch +import torch.nn as nn +from colossalai import nn as col_nn +from colossalai.amp import AMP_TYPE +from colossalai.builder.pipeline import partition_uniform +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.engine.schedule import (InterleavedPipelineSchedule, + PipelineSchedule) +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper +from colossalai.trainer import Trainer, hooks +from colossalai.utils.timer import MultiTimer +from model_zoo.gpt import GPTLMLoss +from torch.nn import functional as F +from torch.utils.data import Dataset +from transformers import GPT2Tokenizer +``` + + + +## Define GPT model + +In the previous tutorial, we introduced 3 ways to build a pipelined model. But for huge models like GPT-3, you can't even build the model in CPU. In this case, you must split the model by yourself. + +GPT dataloader returns `input_ids` and `attention_mask`, so we use two keyword arguments in `forward()` to get them. Note that for stages except the first stage, the first positional argument of `forward()` is the output tensor from the previous stage. So the `hidden_states` is from the previous stage, and for the first stage it's `None`. + +For GPT, the *word embedding layer* shares the weights with the *output head*. We provide `PipelineSharedModuleWrapper` to share parameters among pipeline stages. It takes a `list` of `int` as argument, which means those ranks share the parameters. You can use `register_module()` or `register_parameter()` to register a module or a parameter as the shared module or parameter. If you have multiple sets of shared modules / parameters, you should have multiple `PipelineSharedModuleWrapper` instance. If the parameter is shared within **one** stage, you should not use `PipelineSharedModuleWrapper`, and just use the same module / parameter instance. In this example, the *word embedding layer* is at the first stage, and the *output head* is at the last stage. Thus, they are shared among ranks `[0, pipeline_size - 1]`. + +For the first stage, it maintains the embedding layer and some transformer blocks. For the last stage, it maintains some transformer blocks and the output head layer. For other stages, they just maintain some transformer blocks. `partition_uniform(num_layers, pipeline_size, num_chunks)` returns the parts of all ranks, and the part is a `tuple` of `(start, end)` (exclude end). `start == 0` means that it's the first stage, and `end == num_layers` means it's the last stage. + +```python +class PipelineGPTHybrid(nn.Module): + def __init__(self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + embed_drop_rate: float = 0., + act_func: Callable = F.gelu, + mlp_ratio: int = 4, + attn_drop_rate: float = 0., + drop_rate: float = 0., + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + first: bool = False, + last: bool = False): + super().__init__() + self.embedding = None + self.norm = None + self.head = None + if first: + self.embedding = col_gpt.GPTEmbedding( + hidden_size, vocab_size, max_position_embeddings, dropout=embed_drop_rate, dtype=dtype) + self.blocks = nn.ModuleList([ + col_gpt.GPTBlock(hidden_size, num_attention_heads, mlp_ratio=mlp_ratio, attention_dropout=attn_drop_rate, + dropout=drop_rate, dtype=dtype, checkpoint=checkpoint, activation=act_func) + for _ in range(num_layers) + ]) + if last: + self.norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.head = col_gpt.GPTLMHead(vocab_size=vocab_size, + dim=hidden_size, + dtype=dtype, + bias=False) + + def forward(self, hidden_states=None, input_ids=None, attention_mask=None): + if self.embedding is not None: + hidden_states = self.embedding(input_ids=input_ids) + batch_size = hidden_states.shape[0] + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + for block in self.blocks: + hidden_states, attention_mask = block(hidden_states, attention_mask) + if self.norm is not None: + hidden_states = self.head(self.norm(hidden_states)) + return hidden_states + + +def build_gpt_pipeline(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): + logger = get_dist_logger() + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + rank = gpc.get_global_rank() + wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1]) + parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] + models = [] + for start, end in parts: + kwargs['num_layers'] = end - start + kwargs['first'] = start == 0 + kwargs['last'] = end == num_layers + logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') + chunk = PipelineGPTHybrid(**kwargs).to(device) + if start == 0: + wrapper.register_module(chunk.embedding.word_embeddings) + elif end == num_layers: + wrapper.register_module(chunk.head) + models.append(chunk) + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + return model + + +def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float): + cfg = dict(hidden_size=1600, num_attention_heads=32, checkpoint=checkpoint, dtype=dtype) + return build_gpt_pipeline(48, num_chunks, **cfg) + + +def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float): + cfg = dict(hidden_size=12288, num_attention_heads=96, + checkpoint=checkpoint, max_position_embeddings=2048, dtype=dtype) + return build_gpt_pipeline(96, num_chunks, **cfg) +``` + +## Process the dataset + +We provide a small GPT web-text dataset here. The original format is loose JSON, and we will save the processed dataset. + +```python +class WebtextDataset(Dataset): + def __init__(self, path, seq_len=1024) -> None: + super().__init__() + root = os.path.dirname(path) + encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt') + if os.path.isfile(encoded_data_cache_path): + seq_len_, data, attention_mask = torch.load( + encoded_data_cache_path) + if seq_len_ == seq_len: + self.data = data + self.attention_mask = attention_mask + return + raw_data = [] + with open(path) as f: + for line in f.readlines(): + raw_data.append(json.loads(line)['text']) + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.unk_token + encoded_data = tokenizer( + raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt') + self.data = encoded_data['input_ids'] + self.attention_mask = encoded_data['attention_mask'] + torch.save((seq_len, self.data, self.attention_mask), + encoded_data_cache_path) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return { + 'input_ids': self.data[index], + 'attention_mask': self.attention_mask[index] + }, self.data[index] +``` + +## Training GPT using hybrid parallelism + +In the previous tutorial, we explained the meanings of some pipeline arguments. In this case, we can determine the shape of each output tensor which is exchanged among pipeline stages. For GPT, the shape is `(MICRO BATCH SIZE, SEQUENCE LEN, HIDDEN SIZE)`. By setting this, we can avoid exchanging the tensor shape of each stage. When you are not sure of the tensor shape, you can just leave it `None`, and the shape is inferred automatically. Make sure that the `dtype` of your model is correct. When you use `fp16`, the `dtype` of your model must be `torch.half`. Otherwise, the `dtype` must be `torch.float`. For pipeline parallelism, only `AMP_TYPE.NAIVE` is supported. + +You can easily use tensor parallel by setting `parallel` in `CONFIG`. The data parallelism size is automatically set based on the number of GPUs. + +```python +NUM_EPOCHS = 60 +SEQ_LEN = 1024 +BATCH_SIZE = 192 +NUM_CHUNKS = None +TENSOR_SHAPE = (1, 1024, 1600) +# only pipeline parallel +# CONFIG = dict(parallel=dict(pipeline=2), fp16=dict(mode=AMP_TYPE.NAIVE)) +# pipeline + 1D model parallel +CONFIG = dict(NUM_MICRO_BATCHES = 192, parallel=dict(pipeline=2, tensor=dict(mode='1d', size=2)), fp16=dict(mode=AMP_TYPE.NAIVE)) + + +def train(): + disable_existing_loggers() + parser = colossalai.get_default_parser() + args = parser.parse_args() + colossalai.launch_from_torch(config=CONFIG, backend=args.backend) + logger = get_dist_logger() + + train_ds = WebtextDataset(os.environ['DATA'], seq_len=SEQ_LEN) + train_dataloader = utils.get_dataloader(train_ds, + seed=42, + batch_size=BATCH_SIZE, + pin_memory=True, + shuffle=True, + drop_last=True) + + use_interleaved = NUM_CHUNKS is not None + num_chunks = 1 if not use_interleaved else NUM_CHUNKS + model = GPT2_exlarge_pipeline_hybrid(num_chunks=num_chunks, checkpoint=True, dtype=torch.half) + # model = GPT3_pipeline_hybrid(num_chunks=num_chunks, checkpoint=True, dtype=torch.half) + if use_interleaved and not isinstance(model, nn.ModuleList): + model = nn.ModuleList([model]) + + criterion = GPTLMLoss() + + optimizer = torch.optim.Adam(model.parameters(), lr=0.00015, weight_decay=1e-2,) + + engine, train_dataloader, _, _ = colossalai.initialize(model, + optimizer, + criterion, + train_dataloader=train_dataloader) + global_batch_size = BATCH_SIZE * \ + gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1) + logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0]) + + timer = MultiTimer() + + trainer = Trainer( + engine=engine, + logger=logger, + timer=timer + ) + + hook_list = [ + hooks.LossHook(), + hooks.LogMetricByEpochHook(logger), + hooks.ThroughputHook(), + hooks.LogMetricByStepHook(), + ] + + trainer.fit( + train_dataloader=train_dataloader, + epochs=NUM_EPOCHS, + test_interval=1, + hooks=hook_list, + display_progress=True, + return_output_label=False, + ) +``` diff --git a/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md new file mode 100644 index 0000000000000000000000000000000000000000..b26599740c5f573182bf764d969a0d8c376e8660 --- /dev/null +++ b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md @@ -0,0 +1,247 @@ +# Train ViT Using Pipeline Parallelism + +Author: Hongxin Liu, Yongbin Li + +**Example Code** +- [ColossalAI-Examples Pipeline Parallel ViT](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/vision_transformer/pipeline_parallel) + +**Related Paper** +- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) + +## Introduction + +In this tutorial, you will learn how to train Vision Transformer for image classification from scratch, using pipeline. +Pipeline parallelism is a kind of model parallelism, which is useful when your GPU memory cannot fit your model. +By using it, we split the original model into multi stages, and each stage maintains a part of the original model. +We assume that your GPU memory cannot fit ViT/L-16, and your memory can fit this model. + +## Table of contents + +In this tutorial we will cover: + +1. The definition of ViT model, based on [TIMM](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) +2. Processing the dataset +3. Training ViT using pipeline + +## Import libraries + +```python +import os +from collections import OrderedDict +from functools import partial + +import colossalai +import colossalai.nn as col_nn +import torch +import torch.nn as nn +from colossalai.builder import build_pipeline_model +from colossalai.engine.schedule import (InterleavedPipelineSchedule, + PipelineSchedule) +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.trainer import Trainer, hooks +from colossalai.utils import MultiTimer, get_dataloader +from timm.models import vision_transformer as vit +from torchvision import transforms +from torchvision.datasets import CIFAR10 +``` + + + +## Define Vision Transformer model + +Generally, we provide 3 ways to build a pipelined model: + +1. `colossalai.builder.build_pipeline_model_from_cfg` +2. `colossalai.builder.build_pipeline_model` +3. Split the model by stages by yourself + +When your memory can fit the model, you can use the first two methods to build your model, otherwise you must split the model by yourself. The first two methods first build the whole model on CPU, then split the model, and finally you can just move the corresponding part of model to GPU. + +`colossalai.builder.build_pipeline_model_from_cfg()` receives a config file of model, and it can split the model uniformly (by layer) or balanced (by parameter size). + +If you are familiar with `PyTorch`, you can use `colossalai.builder.build_pipeline_model()` which receives a `torch.nn.Sequential` model and split it by layer uniformly. + +In this tutorial, we will modify [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential` and then use `colossalai.builder.build_pipeline_model()` to build the pipelined model. + +When the data is **one** `Tensor`, you can use the positional argument in `forward()` of your model to get the data tensor. For the first stage of pipeline, the first positional argument of `forward()` is the data tensor loaded from data loader. For other stages, the first positional argument of `forward()` is the output tensor from the previous stage. Note that if the stage is not the last stage, the return of `forward()` must be a `Tensor`. + +When the data is a `dict` of `Tensor`, you can use named keyword arguments in `forward()` of your model to get the data `dict`. + +```python +class ViTEmbedding(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, embed_layer=vit.PatchEmbed, drop_rate=0., distilled=False): + super().__init__() + self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 2 if distilled else 1 + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + self.init_weights() + + def forward(self, x): + x = self.patch_embed(x) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.pos_drop(x + self.pos_embed) + return x + + def init_weights(self): + vit.trunc_normal_(self.pos_embed, std=.02) + if self.dist_token is not None: + vit.trunc_normal_(self.dist_token, std=.02) + vit.trunc_normal_(self.cls_token, std=.02) + self.apply(vit._init_vit_weights) + + +class ViTHead(nn.Module): + def __init__(self, embed_dim=768, num_classes=1000, norm_layer=None, distilled=False, representation_size=None): + super().__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.norm = norm_layer(embed_dim) + self.num_classes = num_classes + self.distilled = distilled + self.num_features = embed_dim + # Representation layer + if representation_size and not distilled: + self.num_features = representation_size + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + # Classifier head(s) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.init_weights() + + def forward(self, x): + x = self.norm(x) + if self.distilled: + x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1]) + if self.training and not torch.jit.is_scripting(): + # during inference, return the average of both classifier predictions + return x, x_dist + else: + return (x + x_dist) / 2 + else: + x = self.pre_logits(x[:, 0]) + x = self.head(x) + return x + + def init_weights(self): + self.apply(vit._init_vit_weights) + + +def sequential_vit(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=vit.PatchEmbed, norm_layer=None, + act_layer=None): + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + embedding = ViTEmbedding(img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, embed_layer=embed_layer, drop_rate=drop_rate, distilled=distilled) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + blocks = [vit.Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)] + for block in blocks: + block.apply(vit._init_vit_weights) + head = ViTHead(embed_dim=embed_dim, num_classes=num_classes, norm_layer=norm_layer, + distilled=distilled, representation_size=representation_size) + return nn.Sequential(embedding, *blocks, head) + + +def vit_large_patch16_224(**kwargs): + model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) + return sequential_vit(**model_kwargs) +``` + +## Process the dataset + +Generally, we train ViT on large dataset like Imagenet. For simplicity, we just use CIFAR-10 here, since this tutorial is just for pipeline training. + +```python +def build_cifar(batch_size): + transform_train = transforms.Compose([ + transforms.RandomCrop(224, pad_if_needed=True), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + transform_test = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train) + test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test) + train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True) + test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True) + return train_dataloader, test_dataloader +``` + +## Training ViT using pipeline + +You can set the size of pipeline parallel and number of microbatches in config. `NUM_CHUNKS` is useful when using interleved-pipeline (for more details see [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) ). The original batch will be split into `num_microbatches`, and each stage will load a micro batch each time. Then we will generate an approriate schedule for you to execute the pipeline training. If you don't need the output and label of model, you can set `return_output_label` to `False` when calling `trainer.fit()` which can further reduce GPU memory usage. + +You should `export DATA=/path/to/cifar`. + +```python +BATCH_SIZE = 16 +NUM_EPOCHS = 60 +NUM_CHUNKS = 1 +CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2)) + + +def train(): + disable_existing_loggers() + parser = colossalai.get_default_parser() + args = parser.parse_args() + colossalai.launch_from_torch(backend=args.backend, config=CONFIG) + logger = get_dist_logger() + + # build model + model = vit_large_patch16_224() + model = build_pipeline_model(model, num_chunks=NUM_CHUNKS, verbose=True) + + # build criterion + criterion = nn.CrossEntropyLoss() + + # optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0) + + # build dataloader + train_dataloader, test_dataloader = build_cifar(BATCH_SIZE) + + engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, optimizer, criterion, + train_dataloader, test_dataloader) + timer = MultiTimer() + + trainer = Trainer(engine=engine, timer=timer, logger=logger) + + hook_list = [ + hooks.LossHook(), + hooks.AccuracyHook(col_nn.metric.Accuracy()), + hooks.LogMetricByEpochHook(logger), + ] + + trainer.fit(train_dataloader=train_dataloader, + epochs=NUM_EPOCHS, + test_dataloader=test_dataloader, + test_interval=1, + hooks=hook_list, + display_progress=True) +``` diff --git a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md new file mode 100644 index 0000000000000000000000000000000000000000..b2438a1cf562a5f6bd0d22c96d12856b8afca4d6 --- /dev/null +++ b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md @@ -0,0 +1,646 @@ +# Step By Step: Accelerate ViT Training With Colossal-AI (From Data Parallel to Hybrid Parallel) + +Author: Yuxuan Lou + +**Example Code** + +- [Colossal-AI Examples ViT on Cifar10](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/vision_transformer) + +**Related Paper** +- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf) + + +## Introduction + +In this example for ViT model, Colossal-AI provides three different parallelism techniques which accelerate model training: data parallelism, pipeline parallelism and tensor parallelism. +We will show you how to train ViT on CIFAR-10 dataset with these parallelism techniques. To run this example, you will need 2-4 GPUs. + + +## Tabel of Contents +1. Colossal-AI installation +2. Steps to train ViT with data parallelism +3. Steps to train ViT with pipeline parallelism +4. Steps to train ViT with tensor parallelism or hybrid parallelism + +## Colossal-AI Installation +You can install Colossal-AI pacakage and its dependencies with PyPI. +```bash +pip install colossalai +``` + + + +## Data Parallelism +Data parallism is one basic way to accelerate model training process. You can apply data parallelism to training by only two steps: +1. Define a configuration file +2. Change a few lines of code in train script + +### Define your configuration file (`data_parallel/config.py`) +To use Colossal-AI, the first step is to define a configuration file. And there are two kinds of variables here: + +1. **Colossal-AI feature specification** + +There is an array of features Colossal-AI provides to speed up training (parallel mode, mixed precision, ZeRO, etc.). Each feature is defined by a corresponding field in the config file. If we apply data parallel only, we do not need to specify the parallel mode. In this example, we use mixed precision training natively provided by PyTorch by define the mixed precision configuration `fp16 = dict(mode=AMP_TYPE.TORCH)`. + +2. **Global hyper-parameters** + +Global hyper-parameters include model-specific hyper-parameters, training settings, dataset information, etc. + +```python +from colossalai.amp import AMP_TYPE + +# ViT Base +BATCH_SIZE = 256 +DROP_RATE = 0.1 +NUM_EPOCHS = 300 + +# mix precision +fp16 = dict( + mode=AMP_TYPE.TORCH, +) + +gradient_accumulation = 16 +clip_grad_norm = 1.0 + +dali = dict( + gpu_aug=True, + mixup_alpha=0.2 +) +``` + +### Modify train script (`/data_parallel/train_with_cifar10.py`) + +#### Import modules +- Colossal-AI related modules +```python +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.lr_scheduler import LinearWarmupLR +from colossalai.nn.metric import Accuracy +from colossalai.trainer import Trainer, hooks +``` + +- Other modules +```python +import os + +import torch +from timm.models import vit_base_patch16_224 + + +from torchvision import transforms +from torchvision.datasets import CIFAR10 +``` + +#### Lauch Colossal-AI + +In train script, you need to initialize the distributed environment for Colossal-AI after your config file is prepared. We call this process `launch`. In Colossal-AI, we provided several launch methods to initialize the distributed backend. In most cases, you can use `colossalai.launch` and `colossalai.get_default_parser` to pass the parameters via command line. Besides, Colossal-AI can utilize the existing launch tool provided by PyTorch as many users are familiar with by using `colossalai.launch_from_torch`. For more details, you can view the related [documents](https://www.colossalai.org/docs/basics/launch_colossalai). + +```python +# initialize distributed setting +parser = colossalai.get_default_parser() +args = parser.parse_args() +colossalai.launch_from_torch(config=args.config) + +disable_existing_loggers() +logger = get_dist_logger() +``` + +After initialization, you can access the variables in the config file by using `colossalai.core.global_context`. + +```python +#access parameters +print(gpc.config.BATCH_SIZE) +``` + +#### Build Model + +If only data parallelism is required, you do not need to make any changes to your model. Here, we use `vit_base_patch16_224` from `timm`. +```python +# build model +model = vit_base_patch16_224(drop_rate=0.1, num_classes=gpc.config.NUM_CLASSES) +``` + +#### Build CIFAR-10 Dataloader +`colossalai.utils.get_dataloader` can help you build dataloader easily. + +```python +def build_cifar(batch_size): + transform_train = transforms.Compose([ + transforms.RandomCrop(224, pad_if_needed=True), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + transform_test = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train) + test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test) + train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True) + test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True) + return train_dataloader, test_dataloader + + +# build dataloader +train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE) +``` + +#### Define optimizer, loss function and LR scheduler + +Colossal-AI provides its own optimizer, loss function and LR scheduler. Those from PyTorch are also compatible. + +```python +# build optimizer +optimizer = colossalai.nn.Lamb(model.parameters(), lr=1.8e-2, weight_decay=0.1) + +# build loss +criterion = torch.nn.CrossEntropyLoss() + +# lr_scheduler +lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS) +``` + +#### Start Colossal-AI engine + +Engine is essentially a wrapper class for model, optimizer and loss function. When we call `colossalai.initialize`, an engine object will be returned, and it has already been equipped with functionalities such as gradient clipping, gradient accumulation and zero optimizer as specified in your configuration file. Further model training is based on Colossal-AI engine. + +```python +engine, train_dataloader, test_dataloader, _ = colossalai.initialize( + model, optimizer, criterion, train_dataloader, test_dataloader + ) +``` + +#### Train: Trainer API +Trainer is a more high-level wrapper for the user to execute training with fewer lines of code. It is easy to create a trainer object by passing the engine object. + +Besides, In trainer, the user can customize some hooks and attach these hooks to the trainer object. A hook object will execute life-cycle methods periodically based on the training scheme. For example, The `LRSchedulerHook` will execute `lr_scheduler.step()` to update the learning rate of the model during either `after_train_iter` or `after_train_epoch` stages. + +```python +# build trainer +trainer = Trainer(engine=engine, logger=logger) + +# build hooks +hook_list = [ + hooks.LossHook(), + hooks.AccuracyHook(accuracy_func=MixupAccuracy()), + hooks.LogMetricByEpochHook(logger), + hooks.LRSchedulerHook(lr_scheduler, by_epoch=True), + + # comment if you do not need to use the hooks below + hooks.SaveCheckpointHook(interval=1, checkpoint_dir='./ckpt'), + hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]), +] +``` + +Use `trainer.fit` for training: + +```python +# start training +trainer.fit( + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + epochs=gpc.config.NUM_EPOCHS, + hooks=hook_list, + display_progress=True, + test_interval=1 +) +``` + +### Start training +`DATA` is the filepath where CIFAR-10 dataset will be automatically downloaded and stored. + +`` is the number of GPUs you want to use to train ViT on CIFAR-10 with data parallelism. + +```bash +export DATA= +# If your torch >= 1.10.0 +torchrun --standalone --nproc_per_node train_dp.py --config ./configs/config_data_parallel.py +# If your torch >= 1.9.0 +# python -m torch.distributed.run --standalone --nproc_per_node= train_dp.py --config ./configs/config_data_parallel.py +# Otherwise +# python -m torch.distributed.launch --nproc_per_node --master_addr --master_port 29500 train_dp.py --config ./configs/config.py +``` + + + +## Pipeline Parallelism +Aside from data parallelism, Colossal-AI also support pipeline parallelism. In specific, Colossal-AI uses 1F1B pipeline introduced by NVIDIA. For more details, you can view the related [documents](https://www.colossalai.org/tutorials/features/pipeline_parallel). + +### Define your configuration file(`hybrid_parallel/configs/vit_pipeline.py`) +To apply pipeline parallel on the data parallel basis, you only need to add a **parallel dict** +```python +from colossalai.amp import AMP_TYPE + +parallel = dict( + pipeline=2 +) +# pipeline config +NUM_MICRO_BATCHES = parallel['pipeline'] +TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 +``` + +Other configs: +```python +# hyper parameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 256 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 300 +WARMUP_EPOCHS = 32 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 768 +DEPTH = 12 +NUM_HEADS = 12 +MLP_RATIO = 4 +NUM_CLASSES = 10 +CHECKPOINT = True +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token +``` + +### Build pipeline model (`/hybrid_parallel/model/vit.py`) +Colossal-AI provides two methods to build a pipeline model from the existing model. +- `colossalai.builder.build_pipeline_model_from_cfg` +- `colossalai.builder.build_pipeline_model` + +Besides, you can also build a pipeline model from scratch with Colossal-AI. +```python +import math +from typing import Callable + +import inspect +import torch +from colossalai import nn as col_nn +from colossalai.registry import LAYERS, MODELS +from colossalai.logging import get_dist_logger +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.builder.pipeline import partition_uniform +from torch import dtype, nn +from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead + + +@MODELS.register_module +class PipelineVisionTransformer(nn.Module): + def __init__(self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + num_classes: int = 1000, + depth: int = 12, + num_heads: int = 12, + dim: int = 768, + mlp_ratio: int = 4, + attention_dropout: float = 0., + dropout: float = 0.1, + drop_path: float = 0., + layernorm_epsilon: float = 1e-6, + activation: Callable = nn.functional.gelu, + representation_size: int = None, + dtype: dtype = None, + bias: bool = True, + checkpoint: bool = False, + init_method: str = 'torch', + first_stage=True, + last_stage=True, + start_idx=None, + end_idx=None,): + super().__init__() + + layers = [] + + if first_stage: + embed = ViTEmbedding(img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embedding_dim=dim, + dropout=dropout, + dtype=dtype, + init_method=init_method) + layers.append(embed) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] + + if start_idx is None and end_idx is None: + start_idx = 0 + end_idx = depth + + blocks = [ + ViTBlock( + dim=dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + attention_dropout=attention_dropout, + dropout=dropout, + drop_path=dpr[i], + activation=activation, + dtype=dtype, + bias=bias, + checkpoint=checkpoint, + init_method=init_method, + ) for i in range(start_idx, end_idx) + ] + layers.extend(blocks) + + if last_stage: + norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) + head = ViTHead(dim=dim, + num_classes=num_classes, + representation_size=representation_size, + dtype=dtype, + bias=bias, + init_method=init_method) + layers.extend([norm, head]) + + self.layers = nn.Sequential( + *layers + ) + + def forward(self, x): + x = self.layers(x) + return x + + +def _filter_kwargs(func, kwargs): + sig = inspect.signature(func) + return {k: v for k, v in kwargs.items() if k in sig.parameters} + + +def _build_pipeline_vit(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs): + logger = get_dist_logger() + if gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + else: + pipeline_size = 1 + pipeline_rank = 0 + rank = gpc.get_global_rank() + parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] + models = [] + + for start, end in parts: + kwargs['first_stage'] = start == 0 + kwargs['last_stage'] = end == num_layers + kwargs['start_idx'] = start + kwargs['end_idx'] = end + logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') + chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device) + models.append(chunk) + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + return model + + +def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): + return _build_pipeline_vit(PipelineVisionTransformer, num_layers, num_chunks, device, **kwargs) +``` + +### Modify train script (`/hybrid_parallel/train_with_cifar10.py`) + +#### Import modules +```python +from colossalai.engine.schedule import (InterleavedPipelineSchedule, + PipelineSchedule) +from colossalai.utils import MultiTimer +import os + +import colossalai + +import torch +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.utils import is_using_pp, get_dataloader +from model.vit import build_pipeline_vit +from model_zoo.vit.vit import _create_vit_model +from tqdm import tqdm + +from torchvision import transforms +from torchvision.datasets import CIFAR10 +``` + +#### Launch Colossal-AI +`colossalai.utils.is_using_pp` can help check whether pipeline parallelism is required in config file. + +```python +# initialize distributed setting +parser = colossalai.get_default_parser() +args = parser.parse_args() + +# launch from torch +colossalai.launch_from_torch(config=args.config) + +# get logger +logger = get_dist_logger() +logger.info("initialized distributed environment", ranks=[0]) + +if hasattr(gpc.config, 'LOG_PATH'): + if gpc.get_global_rank() == 0: + log_path = gpc.config.LOG_PATH + if not os.path.exists(log_path): + os.mkdir(log_path) + logger.log_to_file(log_path) + +use_pipeline = is_using_pp() +``` + +#### Define model + +```python +# create model +model_kwargs = dict(img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + dim=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=gpc.config.NUM_CLASSES, + init_method='jax', + checkpoint=gpc.config.CHECKPOINT) + +if use_pipeline: + model = build_pipeline_vit(num_layers=model_kwargs['depth'], num_chunks=1, **model_kwargs) +else: + model = _create_vit_model(**model_kwargs) +``` + +#### Count number of parameters + +You can count model parameters on different pipeline stages easily. + +``` +# count number of parameters +total_numel = 0 +for p in model.parameters(): + total_numel += p.numel() +if not gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_stage = 0 +else: + pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) +logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") +``` + +#### Build dataloader, optimizer, etc. + +```python +def build_cifar(batch_size): + transform_train = transforms.Compose([ + transforms.RandomCrop(224, pad_if_needed=True), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + transform_test = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train) + test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test) + train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True) + test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True) + return train_dataloader, test_dataloader + + +# create dataloaders +train_dataloader , test_dataloader = build_cifar() + +# create loss function +criterion = CrossEntropyLoss(label_smoothing=0.1) + +# create optimizer +optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) + +# create lr scheduler +lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS, + warmup_steps=gpc.config.WARMUP_EPOCHS) +``` + +#### Start Colossal-AI engine + +```python +# initialize +engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader) + +logger.info("Engine is built", ranks=[0]) +``` + +#### Train: based on engine + +In the data parallelism example, we show how to train a model with Trainer API. We can also directly train a model based on engine. In this way, you can customize your training with more features. + +```python +data_iter = iter(train_dataloader) + +for epoch in range(gpc.config.NUM_EPOCHS): + # training + engine.train() + + if gpc.get_global_rank() == 0: + description = 'Epoch {} / {}'.format( + epoch, + gpc.config.NUM_EPOCHS + ) + progress = tqdm(range(len(train_dataloader)), desc=description) + else: + progress = range(len(train_dataloader)) + for _ in progress: + engine.zero_grad() + engine.execute_schedule(data_iter, return_output_label=False) + engine.step() + lr_scheduler.step() +``` + +### Start training +```bash +export DATA= +# If your torch >= 1.10.0 +torchrun --standalone --nproc_per_node train_hybrid.py --config ./configs/config_pipeline_parallel.py +# If your torch >= 1.9.0 +# python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_pipeline_parallel.py +``` + + + + +## Tensor Parallelism and Hybrid Parallelism +Tensor parallelism partitions each weight parameter across multiple devices in order to reduce memory load. Colossal-AI support 1D, 2D, 2.5D and 3D tensor parallelism. Besides, you can combine tensor parallelism with pipeline parallelism and data parallelism to reach hybrid parallelism. Colossal-AI also provides an easy way to apply tensor parallelism and hybrid parallelism. On the basis of pipeline parallelism, a few lines of code changing in config file is all you need. + +### Define your configuration file(`/hybrid_parallel/configs/vit_1d_tp2_pp2.py`) +To use tensor parallelism, you only need to add related information to the **parallel dict**. To be specific, `TENSOR_PARALLEL_MODE` can be '1d', '2d', '2.5d', '3d'. And the size of different parallelism should satisfy: `#GPUs = pipeline parallel size x tensor parallel size x data parallel size`. `data parallel size` will automatically computed after you specify the number of GPUs, pipeline parallel size and tensor parallel size. + +```python +from colossalai.amp import AMP_TYPE +# parallel setting +TENSOR_PARALLEL_SIZE = 2 +TENSOR_PARALLEL_MODE = '1d' + +parallel = dict( + pipeline=2, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE) +) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 + + +# pipeline config +NUM_MICRO_BATCHES = parallel['pipeline'] +TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE) +``` + +Ohter configs: +```python +# hyper parameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 256 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 300 +WARMUP_EPOCHS = 32 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 768 +DEPTH = 12 +NUM_HEADS = 12 +MLP_RATIO = 4 +NUM_CLASSES = 10 +CHECKPOINT = True +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token +``` + +### Start training +```bash +export DATA= +# If your torch >= 1.10.0 +torchrun --standalone --nproc_per_node train_hybrid.py --config ./configs/config_hybrid_parallel.py +# If your torch >= 1.9.0 +# python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_hybrid_parallel.py +``` diff --git a/docs/source/en/basics/colotensor_concept.md b/docs/source/en/basics/colotensor_concept.md new file mode 100644 index 0000000000000000000000000000000000000000..909c5e4d3c6f13933e80a7dbbe00473c141c4613 --- /dev/null +++ b/docs/source/en/basics/colotensor_concept.md @@ -0,0 +1,96 @@ +# ColoTensor Concepts + +Author: [Jiarui Fang](https://github.com/feifeibear), [Hongxin Liu](https://github.com/ver217) and [Haichen Huang](https://github.com/1SAA) + +**Prerequisite:** +- [Colossal-AI Overview](../concepts/colossalai_overview.md) +- [Distributed Training](../concepts/distributed_training.md) +- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md) + +## Introduction + +After ColossalAI version 0.1.8, [ColoTensor](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ColoTensor) becomes the basic data structure for tensors in ColossalAI. It is a subclass of torch.Tensor and can be used as a PyTorch Tensor. Additionally, some unique features make it possible to represent a Global Tensor with a payload distributed across multiple GPU devices. With the help of ColoTensor, the users can write distributed DNN training program similar to a serial one.support the following features. + +ColoTensor contains extra attributes capsuled in a [ColoTensorSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.tensor_spec.html#colossalai.tensor.tensor_spec.ColoTensorSpec) instance to describe the tensor's payload distribution and computing pattern. + +- ProcessGroup: how processes are organized as communication groups. +- Distributed Spec: how tensor is distributed among process groups. +- Compute Spec: how the tensor is used during computation. + +We elaborate on them one by one. + +## ProcessGroup + +An instance of class [ProcessGroup](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ProcessGroup) describes how processes are organized in process groups. Processes in a process group can participate in the same collective communication operations together, such as allgather, allreduce, etc. The way the process group is organized is dominated by the Tensor's parallelism strategy. For example, if the user defines the tensor parallel (TP) and data parallel (DP) modes of a tensor, then the process organization of the process group will be automatically deduced. The process group settings can vary among different tensors. Therefore, it enables us to support more complicated hybrid parallel. The pipeline parallel (PP) definition is not in the ProcessGroup, it needs another set of mechanisms . We will supplement the related content of ColoTensor applied to PP in the future. + +Currently, a process group of ColoTensor is defined by two configurations, i.e. tp_degree and dp_degree. In the case of DP+TP hybrid parallelism, the device can be viewed as a 2D mesh. We place TP communication groups on the leading low dimension of the device mesh and then place the data parallel groups along the high dimension of the device mesh. The reason is that tensor parallelism has a larger communication overhead than data parallelism. Neighboring devices are placed inside a TP process group and are often placed in the same node. + +Considering that 8 processes are configured as tp_degree=4, and dp_degree=2, the layout is shown below. Process group tp0 contains gpu 0,1,2,3. Process dp1 contains gpu 1 and 5. + +
+ +
Process Group using tp_degree=4, dp_degree=2
+
+ +## Distributed Spec + +An instance of [Distributed Spec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html) describes how a ColoTensor is distributed among the ProcessGroup. + +How tensors are distributed among DP process groups is automatically derived and does not need to be manually specified by the user. If this tensor is a model parameter, it is replicated within the DP process group. If it is an activation tensor, it is split along the process with the highest dimension and evenly distributed the tensor payload among processes in the DP process group. + +Therefore, when using Distributed Spec, we only need to describe the way that the tensor is distributed among TP process groups. There are currently two ways to distribute among TP process group, i.e. [ShardSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ShardSpec) and [ReplicaSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ReplicaSpec). ShardSpec needs to specify the dimension index dim of the partition and the number of partitions num_partitions. Currently, we only support the split on a single dim. Different dist specs on the TP process groups can be converted to each other through the set_dist_spec() interface. The spec conversions are recorded by the autograd mechanism and it will trigger corresponding reverse operations during backward propagation. + +## Compute Spec + +An instance of class [ComputeSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.compute_spec.html#colossalai.tensor.compute_spec.ComputeSpec) describes how a Colotensor be used in DNN training. Currently, we will set the correct Compute Pattern for the ColoTensor as the parameters of the module. The specific application scenarios will be shown in the next document. + +## ColoParameter + +[ColoParameter](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.colo_parameter.html#colossalai.tensor.colo_parameter.ColoParameter) is a subclass of ColoTensor. Used to define a Global Parameter tensor. Its relationship with ColoTensor is consistent with Torch.Tensor and torch.Parameter. The latter allows the tensor to appear in the return values of the module's parameters() and name_parameters() methods. + +## Example + +Let's see an example. A ColoTensor is initialized and sharded on 8 GPUs using tp_degree=4, dp_dgree=2. And then the tensor is sharded along the last dim among the TP process groups. Finally, we reshard it along the first dim (0 dim) among the TP process groups. We encourage users to run the code and observe the shape of each tensor. + + +```python +import torch +import torch.multiprocessing as mp +from colossalai.utils import print_rank_0 +from functools import partial + +import colossalai +from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern +from colossalai.testing import spawn + +import torch + +def run_dist_tests(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(tp_degree=2, dp_degree=2) + + torch.manual_seed(0) + local_tensor = torch.randn(2, 3, 1).cuda() + print_rank_0(f"shape {local_tensor.shape}, {local_tensor.data}") + + spec = ColoTensorSpec(pg, ShardSpec(dims=[-1], num_partitions=[pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + t1 = ColoTensor.from_torch_tensor(local_tensor, spec) + t1 = t1.to_replicate() + print_rank_0(f"shape {t1.shape}, {t1.data}") + + spec2 = ShardSpec([0], [pg.tp_world_size()]) + t1.set_dist_spec(spec2) + print_rank_0(f"shape {t1.shape}, {t1.data}") + +def test_dist_cases(world_size): + spawn(run_dist_tests, world_size) + +if __name__ == '__main__': + test_dist_cases(4) +``` + +:::caution + +The ColoTensor is an experimental feature and may be updated. + +::: diff --git a/docs/source/en/basics/command_line_tool.md b/docs/source/en/basics/command_line_tool.md new file mode 100644 index 0000000000000000000000000000000000000000..48b199cf78e9b36118443b482ff867c0bfc7996a --- /dev/null +++ b/docs/source/en/basics/command_line_tool.md @@ -0,0 +1,53 @@ +# Command Line Tool + +Author: Shenggui Li + +**Prerequisite:** +- [Distributed Training](../concepts/distributed_training.md) +- [Colossal-AI Overview](../concepts/colossalai_overview.md) + +## Introduction + +Colossal-AI provides command-line utilities for the user. +The current command line tools support the following features. + +- verify Colossal-AI build +- launch distributed jobs +- tensor parallel micro-benchmarking + +## Check Installation + +To verify whether your Colossal-AI is built correctly, you can use the command `colossalai check -i`. +This command will inform you information regarding the version compatibility and cuda extension. + +
+ +
Check Installation Demo
+
+ +## Launcher + +To launch distributed jobs on single or multiple nodes, the command `colossalai run` can be used for process launching. +You may refer to [Launch Colossal-AI](./launch_colossalai.md) for more details. + +## Tensor Parallel Micro-Benchmarking + +As Colossal-AI provides an array of tensor parallelism methods, it is not intuitive to choose one for your hardware and +model. Therefore, we provide a simple benchmarking to evaluate the performance of various tensor parallelisms on your system. +This benchmarking is run on a simple MLP model where the input data is of the shape `(batch_size, seq_length, hidden_size)`. +Based on the number of GPUs, the CLI will look for all possible tensor parallel configurations and display the benchmarking results. +You can customize the benchmarking configurations by checking out `colossalai benchmark --help`. + +```shell +# run on 4 GPUs +colossalai benchmark --gpus 4 + +# run on 8 GPUs +colossalai benchmark --gpus 8 +``` + +:::caution + +Only single-node benchmarking is supported currently. + +::: diff --git a/docs/source/en/basics/configure_parallelization.md b/docs/source/en/basics/configure_parallelization.md new file mode 100644 index 0000000000000000000000000000000000000000..4ac0299eac14252eb55f3f635d5747702d2033d5 --- /dev/null +++ b/docs/source/en/basics/configure_parallelization.md @@ -0,0 +1,156 @@ +# Configure Parallelization + +Author: Shenggui Li, Siqi Mai + +**Prerequisite:** +- [Distributed Training](../concepts/distributed_training.md) +- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md) +- [Define Your Configuration](./define_your_config.md) + + +## Introduction + +We support multiple parallelization in Colossal-AI. Hybrid parallelism in our codebase refers to namely the combination +of data parallelism, pipeline parallelism and tensor parallelism (1D, 2D, 2.5D, 3D). + +Each parallelism requires different network topology and thus initialize different process groups. +You can initialize the corresponding process group by setting `parallel` in the config file. +The configuration for `parallel` must obey the following format. Data parallel size will be +inferred automatically based on your inputs to pipeline parallelism and tensor parallelism. +`colossalai.launch` will initialize these distributed process groups automatically based on your configuration. + +Some sample configurations are shown below: + +```python +# sampler format +parallel = dict( + pipeline=dict("size": int), + tensor=dict("size": int, "mode": '1d' or '2d' or '2.5d' or '3d', "kwargs": Any) +) + +# this is ok +parallel = dict( + pipeline=dict(size=2), + tensor=dict(size=4, mode='2d') +) + +# this is ok +parallel = dict( + pipeline=2, + tensor=dict(size=4, mode='2d') +) + +# this is not ok +# as you need to specify the mode for tensor parallelism +parallel = dict( + pipeline=2, + tensor=4 +) + +# this is ok as well as tensor will be default to size 1 +# and mode None +parallel = dict( + pipeline=2 +) + +# this is ok as well as pipeline will default to size 1 +parallel = dict( + tensor=dict(size=4, mode='2d') +) + +``` + +The key name `size` refers to the parallel size of the parallelism dimension. For example, pipeline size 2 means there +will be 2 pipeline stages. The key name `mode` in tensor parallel config means the corresponding tensor parallelism +will be initialized. + +**You can choose to not have 'parallel' in your configuration and both pipeline and tensor will default to size 1.** + +**Total number of GPUs must be equal to `data parallel size * tensor parallel size * pipeline parallel size`** + +## Data Parallel + +Data parallel is the most common way to distribute your training task by splitting data into several shards and train on +a single shard on each device. The configuration for data parallel is detected automatically and set for you. You do not +have to explicitly set them in your configurations. There are two ways to handle the all-reduce in data parallel in Colossal-AI. + +1. If you specify gradient handlers, gradients will be all-reduced according to the gradient handlers +2. Otherwise, PyTorch DistributedDataParallel will be used + +In most cases, you will be using the second mode unless you have complex handling of the gradients. + +## 1D, 2D, 2.5D and 3D Parallel + +To enable hybrid parallelism, we provide an array of tensor parallelism. We provide the list of papers which match each +tensor parallel method. These parallel modes need to work with the distributed layers provided by Colossal-AI. + +- 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) + +- 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343) + 2D parallel relies on the SUMMA matrix multiplication algorithm and splits the input data, model weights and layer + outputs along two different dimensions. The tensor chunks are distributed over a 2D mesh of `P = N^2` devices where + `N` is the number of tensor chunks in a single dimension. + +- 2.5D: [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500) + Inspired by the 2.5D matrix multiplication algorithm, 2.5D parallel introduces a novel tensor parallelism which + further parallelizes 2D tensor parallelism. An amount of `P = N^2 ∗ d` processors are arranged into `d` layers, where + each layer performs matrix multiplication operations independently with a dimension `N`. + +- 3D: [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450) + We also introduce a 3D tensor parallelism that parallelizes neural networks on a 3D processor cube. This method + achieves the optimal, `O(P^{1/3})` communication overhead on $P$ processors, while both computation and memory usage + are evenly distributed through optimized load balancing of parameters as well as activations. + +```python +# 1D parallel +parallel = dict( + tensor=dict(size=4, mode='1d') +) + +# 2D parallel +parallel = dict( + tensor=dict(size=4, mode='2d') +) + +# 2.5D parallel +parallel = dict( + tensor=dict(size=8, mode='2.5d', depth=2) +) + +# 3D parallel +parallel = dict( + tensor=dict(size=8, mode='3d') +) +``` + +Once you specify the tensor parallel mode in your configuration, you can proceed to use its corresponding distributed +operator. For example, if you mode is '2d', you can use `colossalai.nn.Linear2D` in you model construction. + + +## Pipeline Parallel + +Pipeline parallelism is to split the model into several partitions by layer. For example, let's assume we have a simple +model which consists of two linear layer. We have two GPUs, and we can allocate the first linear layer to the first GPU +and the second layer to the second GPU. + +You can set the number of pipeline stages in your configuration file. When pipeline size is larger than 1, Colossal-AI +will automatically creates the pipeline schedule which defines the forward and backward step. + +```python +parallel = dict( + pipeline=dict(size=4), # number of pipeline stages +) +``` + +## Sequence Parallel + +Sequence parallel is to support long-sequence modelling such as document-level text understanding and medical imaging. +This method is proposed in [Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120). +You can use specify the mode to be `sequence` to initialize its process group. + + +```python +parallel = dict( + tensor=dict(size=4, mode='sequence') +) +``` diff --git a/docs/source/en/basics/define_your_config.md b/docs/source/en/basics/define_your_config.md new file mode 100644 index 0000000000000000000000000000000000000000..d2569691b7dc2f910bd5e676ef3b8b2978b53bc2 --- /dev/null +++ b/docs/source/en/basics/define_your_config.md @@ -0,0 +1,82 @@ +# Define Your Configuration + +Author: Guangyang Lu, Shenggui Li, Siqi Mai + +**Prerequisite:** +- [Distributed Training](../concepts/distributed_training.md) +- [Colossal-AI Overview](../concepts/colossalai_overview.md) + + +## Introduction + +In Colossal-AI, a configuration file is required to specify the features the system will inject into the training process. +In this tutorial, we will introduce you how to construct your configuration file and how this config file will be used. +Using configuration file has several advantages: + +1. You can store your feature configuration and training hyper-parameters in different configuration files +2. New features released in the future can be specified in the configuration without code change in the training script + +In this tutorial, we will cover how to define your configuration file. + +## Configuration Definition + +In a configuration file, there are two types of variables. One serves as feature specification and the other serves +as hyper-parameters. All feature-related variables are reserved keywords. For example, if you want to use mixed precision +training, you need to use the variable name `fp16` in the config file and follow a pre-defined format. + +### Feature Specification + +There is an array of features Colossal-AI provides to speed up training. Each feature is defined by a corresponding field +in the config file. In this tutorial, we are not giving the config details for all the features, but rather we are providing +an illustration of how to specify a feature. **The details of each feature can be found in its respective tutorial.** + +To illustrate the use of config file, we use mixed precision training as an example here. In order to do so, you need to +follow the steps below. + +1. create a configuration file (e.g. `config.py`, the file name can be anything) +2. define the mixed precision configuration in the config file. For example, in order to use mixed precision training +natively provided by PyTorch, you can just write these lines of code below into your config file. + + ```python + from colossalai.amp import AMP_TYPE + + fp16 = dict( + mode=AMP_TYPE.TORCH + ) + ``` + +3. Tell Colossal-AI where your config file is when launch the distributed environment. For example, the config file is in +the current directory. + + ```python + import colossalai + + colossalai.launch(config='./config.py', ...) + ``` + +In this way, Colossal-AI knows what features you want to use and will inject this feature during `colossalai.initialize`. + +### Global Hyper-parameters + +Besides feature specification, the config file can also serve as a place to define your training hyper-parameters. This +comes handy when you want to perform multiple experiments, each experiment details can be put into a single config file +to avoid confusion. These parameters will be stored in the global parallel context and can be accessed in the training script. + +For example, you can specify the batch size in your config file. + +```python +BATCH_SIZE = 32 +``` + +After launch, you are able to access your hyper-parameters through global parallel context. + +```python +import colossalai +from colossalai.core import global_context as gpc + +colossalai.launch(config='./config.py', ...) + +# access your parameter +print(gpc.config.BATCH_SIZE) + +``` diff --git a/docs/source/en/basics/engine_trainer.md b/docs/source/en/basics/engine_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..bbe32ed5a3b579805d99ead1169b69a196916626 --- /dev/null +++ b/docs/source/en/basics/engine_trainer.md @@ -0,0 +1,387 @@ +# Use Engine and Trainer in Training + +Author: Shenggui Li, Siqi Mai + +**Prerequisite:** +- [Initialize Features](./initialize_features.md) + +## Introduction + +In this tutorial, you will learn how to use the engine and trainer provided in Colossal-AI to train your model. +Before we delve into the details, we would like to first explain the concept of engine and trainer. + +### Engine + +Engine is essentially a wrapper class for model, optimizer and loss function. +When we call `colossalai.initialize`, an engine object will be returned, and it has already been equipped with +functionalities such as gradient clipping, gradient accumulation and zero optimizer as specified in your configuration file. +An engine object will use similar APIs to those of PyTorch training components such that the user has minimum change +to their code. + +Below is a table which shows the commonly used APIs for the engine object. + +| Component | Function | PyTorch | Colossal-AI | +| ------------------------------------- | --------------------------------------------- | ------------------------------- | -------------------------------------- | +| optimizer | Set all gradients to zero before an iteration | optimizer.zero_grad() | engine.zero_grad() | +| optimizer | Update the parameters | optimizer.step() | engine.step() | +| model | Run a forward pass | outputs = model(inputs) | outputs = engine(inputs) | +| criterion | Calculate the loss value | loss = criterion(output, label) | loss = engine.criterion(output, label) | +| criterion | Execute back-propagation on the model | loss.backward() | engine.backward(loss) | + +The reason why we need such an engine class is that we can add more functionalities while hiding the implementations in +the `colossalai.initialize` function. +Imaging we are gonna add a new feature, we can manipulate the model, optimizer, dataloader and loss function in the +`colossalai.initialize` function and only expose an engine object to the user. +The user only needs to modify their code to the minimum extent by adapting the normal PyTorch APIs to the Colossal-AI +engine APIs. In this way, they can enjoy more features for efficient training. + +A normal training iteration using engine can be: + +```python +import colossalai + +# build your model, optimizer, criterion, dataloaders +... + +engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, + optimizer, + criterion, + train_dataloader, + test_dataloader) +for img, label in train_dataloader: + engine.zero_grad() + output = engine(img) + loss = engine.criterion(output, label) + engine.backward(loss) + engine.step() +``` + +### Trainer + +Trainer is a more high-level wrapper for the user to execute training with fewer lines of code. However, in pursuit of more abstraction, it loses some flexibility compared to engine. The trainer is designed to execute a forward and backward step to perform model weight update. It is easy to create a trainer object by passing the engine object. The trainer has a default value `None` for the argument `schedule`. In most cases, we leave this value to `None` unless we want to use pipeline parallelism. If you wish to explore more about this parameter, you can go to the tutorial on pipeline parallelism. + +```python +from colossalai.logging import get_dist_logger +from colossalai.trainer import Trainer, hooks + +# build components and initialize with colossalai.initialize +... + +# create a logger so that trainer can log on the console +logger = get_dist_logger() + +# create a trainer object +trainer = Trainer( + engine=engine, + logger=logger +) +``` + + + +In trainer, the user can customize some hooks and attach these hooks to the trainer object. A hook object will execute life-cycle methods periodically based on the training scheme. For example, The `LRSchedulerHook` will execute `lr_scheduler.step()` to update the learning rate of the model during either `after_train_iter` or `after_train_epoch` stages depending on whether the user wants to update the learning rate after each training iteration or only after the entire training epoch. You can store the hook objects in a list and pass it to `trainer.fit` method. `trainer.fit` method will execute training and testing based on your parameters. If `display_process` is True, a progress bar will be displayed on your console to show the training process. + +```python +# define the hooks to attach to the trainer +hook_list = [ + hooks.LossHook(), + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True), + hooks.AccuracyHook(accuracy_func=Accuracy()), + hooks.LogMetricByEpochHook(logger), +] + +# start training +trainer.fit( + train_dataloader=train_dataloader, + epochs=NUM_EPOCHS, + test_dataloader=test_dataloader, + test_interval=1, + hooks=hook_list, + display_progress=True +) +``` + +If you want to customize your own hook class, you can inherit `hooks.BaseHook` and override the life-cycle methods of your interest. A dummy example to demonstrate how to create a simple log message hook is provided below for your reference. + +```python +from colossalai.logging import get_dist_logger +from colossalai.trainer import hooks + +class LogMessageHook(hooks.BaseHook): + + def __init__(self, priority=10): + self._logger = get_dist_logger() + + def before_train(self, trainer): + self._logger.info('training starts') + + def after_train(self, trainer): + self._logger.info('training finished') + + +... + +# then in your training script +hook_list.append(LogMessageHook()) +``` + + + +In the sections below, I will guide you through the steps required to train a ResNet model with both engine and trainer. + + + +## Explain with ResNet + +### Overview + +In this section we will cover: + +1. Use an engine object to train a ResNet34 model on CIFAR10 dataset +2. Use a trainer object to train a ResNet34 model on CIFAR10 dataset + +The project structure will be like: + +```bash +-- config.py +-- run_resnet_cifar10_with_engine.py +-- run_resnet_cifar10_with_trainer.py +``` + +Steps 1-4 below are commonly used regardless of using engine or trainer. Thus, steps 1-4 + step 5 will be your `run_resnet_cifar10_with_engine.py` and steps 1-4 + step 6 will form `run_resnet_cifar10_with_trainer.py`. + +### Hands-on Practice + +#### Step 1. Create a Config File + +In your project folder, create a `config.py`. This file is to specify some features you may want to use to train your model. A sample config file is as below: + +```python +from colossalai.amp import AMP_TYPE + +BATCH_SIZE = 128 +NUM_EPOCHS = 200 + +fp16=dict( + mode=AMP_TYPE.TORCH +) +``` + +In this config file, we specify that we want to use batch size 128 per GPU and run for 200 epochs. These two parameters are exposed by `gpc.config`. For example, you can use `gpc.config.BATCH_SIZE` to access the value you store in your config file. The `fp16` configuration tells `colossalai.initialize` to use mixed precision training provided by PyTorch to train the model with better speed and lower memory consumption. + +#### Step 2. Initialize Distributed Environment + +We need to initialize the distributed training environment. This has been introduced in the tutorial on how to +[launch Colossal-AI](./launch_colossalai.md). For this demonstration, we use `launch_from_torch` and PyTorch launch utility. + +```python +import colossalai + +# ./config.py refers to the config file we just created in step 1 +colossalai.launch_from_torch(config='./config.py') +``` + +#### Step 3. Create all the training components + +In this step, we can create all the components used for training. These components include: + +1. Model +2. Optimizer +3. Criterion/loss function +4. Training/Testing dataloaders +5. Learning rate Scheduler +6. Logger + + + +To build these components, you need to import the following modules: + +```python +from pathlib import Path +from colossalai.logging import get_dist_logger +import torch +import os +from colossalai.core import global_context as gpc +from colossalai.utils import get_dataloader +from torchvision import transforms +from colossalai.nn.lr_scheduler import CosineAnnealingLR +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet34 +``` + + + +Then build your components in the same way as how to normally build them in your PyTorch scripts. In the script below, we set the root path for CIFAR10 dataset as an environment variable `DATA`. You can change it to any path you like, for example, you can change `root=Path(os.environ['DATA'])` to `root='./data'` so that there is no need to set the environment variable. + +```python +# build logger +logger = get_dist_logger() + +# build resnet +model = resnet34(num_classes=10) + +# build datasets +train_dataset = CIFAR10( + root='./data', + download=True, + transform=transforms.Compose( + [ + transforms.RandomCrop(size=32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[ + 0.2023, 0.1994, 0.2010]), + ] + ) +) + +test_dataset = CIFAR10( + root='./data', + train=False, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[ + 0.2023, 0.1994, 0.2010]), + ] + ) +) + +# build dataloaders +train_dataloader = get_dataloader(dataset=train_dataset, + shuffle=True, + batch_size=gpc.config.BATCH_SIZE, + num_workers=1, + pin_memory=True, + ) + +test_dataloader = get_dataloader(dataset=test_dataset, + add_sampler=False, + batch_size=gpc.config.BATCH_SIZE, + num_workers=1, + pin_memory=True, + ) + +# build criterion +criterion = torch.nn.CrossEntropyLoss() + +# optimizer +optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) + +# lr_scheduler +lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS) +``` + +#### Step 4. Initialize with Colossal-AI + +Next, the essential step is to obtain the engine class by calling `colossalai.initialize`. As stated in `config.py`, we will be using mixed precision training for training ResNet34 model. `colossalai.initialize` will automatically check your config file and assign relevant features to your training components. In this way, our engine object has already been able to train with mixed precision, but you do not have to explicitly take care of it. + +```python +engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, + optimizer, + criterion, + train_dataloader, + test_dataloader, + ) +``` + + + +#### Step 5. Train with engine + +With all the training components ready, we can train ResNet34 just like how to normally deal with PyTorch training. + +```python +for epoch in range(gpc.config.NUM_EPOCHS): + # execute a training iteration + engine.train() + for img, label in train_dataloader: + img = img.cuda() + label = label.cuda() + + # set gradients to zero + engine.zero_grad() + + # run forward pass + output = engine(img) + + # compute loss value and run backward pass + train_loss = engine.criterion(output, label) + engine.backward(train_loss) + + # update parameters + engine.step() + + # update learning rate + lr_scheduler.step() + + # execute a testing iteration + engine.eval() + correct = 0 + total = 0 + for img, label in test_dataloader: + img = img.cuda() + label = label.cuda() + + # run prediction without back-propagation + with torch.no_grad(): + output = engine(img) + test_loss = engine.criterion(output, label) + + # compute the number of correct prediction + pred = torch.argmax(output, dim=-1) + correct += torch.sum(pred == label) + total += img.size(0) + + logger.info( + f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", ranks=[0]) +``` + +#### Step 6. Train with trainer + +If you wish to train with a trainer object, you can follow the code snippet below: + +```python +from colossalai.nn.metric import Accuracy +from colossalai.trainer import Trainer, hooks + + +# create a trainer object +trainer = Trainer( + engine=engine, + logger=logger +) + +# define the hooks to attach to the trainer +hook_list = [ + hooks.LossHook(), + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True), + hooks.AccuracyHook(accuracy_func=Accuracy()), + hooks.LogMetricByEpochHook(logger), + hooks.LogMemoryByEpochHook(logger) +] + +# start training +# run testing every 1 epoch +trainer.fit( + train_dataloader=train_dataloader, + epochs=gpc.config.NUM_EPOCHS, + test_dataloader=test_dataloader, + test_interval=1, + hooks=hook_list, + display_progress=True +) +``` + + + +#### Step 7. Start Distributed Training + +Lastly, we can invoke the scripts using the distributed launcher provided by PyTorch as we used `launch_from_torch` in Step 2. You need to replace `` with the number of GPUs available on your machine. This number can be 1 if you only want to use 1 GPU. If you wish to use other launchers, you can refer to the tutorial on How to Launch Colossal-AI. + +```bash +# with engine +python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py +# with trainer +python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py +``` diff --git a/docs/source/en/basics/initialize_features.md b/docs/source/en/basics/initialize_features.md new file mode 100644 index 0000000000000000000000000000000000000000..e768d2022ad8d706980a1dcd1b90ffd7a8be5367 --- /dev/null +++ b/docs/source/en/basics/initialize_features.md @@ -0,0 +1,49 @@ +# Initialize Features + +Author: Shenggui Li, Siqi Mai + +**Prerequisite:** +- [Distributed Training](../concepts/distributed_training.md) +- [Colossal-AI Overview](../concepts/colossalai_overview.md) + +## Introduction + +In this tutorial, we will cover the use of `colossalai.initialize` which injects features into your training components +(e.g. model, optimizer, dataloader) seamlessly. Calling `colossalai.initialize` is the standard procedure before you run +into your training loops. + +In the section below, I will cover how `colossalai.initialize` works and what we should take note of. + +## Usage + +In a typical workflow, we will launch distributed environment at the beginning of our training script. +Afterwards, we will instantiate our objects such as model, optimizer, loss function, dataloader etc. At this moment, `colossalai.initialize` +can come in to inject features into these objects. A pseudo-code example is like below: + +```python +import colossalai +import torch +... + + +# launch distributed environment +colossalai.launch(config='./config.py', ...) + +# create your objects +model = MyModel() +optimizer = torch.optim.Adam(model.parameters(), lr=0.001) +criterion = torch.nn.CrossEntropyLoss() +train_dataloader = MyTrainDataloader() +test_dataloader = MyTrainDataloader() + +# initialize features +engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, + optimizer, + criterion, + train_dataloader, + test_dataloader) +``` + +The `colossalai.initialize` function will return an `Engine` object. The engine object is a wrapper +for model, optimizer and loss function. **The engine object will run with features specified in the config file.** +More details about the engine can be found in the [Use Engine and Trainer in Training](./engine_trainer.md). diff --git a/docs/source/en/basics/launch_colossalai.md b/docs/source/en/basics/launch_colossalai.md new file mode 100644 index 0000000000000000000000000000000000000000..be487f8539a57ce1995516f4e95a9d8d480b9630 --- /dev/null +++ b/docs/source/en/basics/launch_colossalai.md @@ -0,0 +1,232 @@ +# Launch Colossal-AI + +Author: Chuanrui Wang, Shenggui Li, Siqi Mai + +**Prerequisite:** +- [Distributed Training](../concepts/distributed_training.md) +- [Colossal-AI Overview](../concepts/colossalai_overview.md) + + +## Introduction + +As mentioned in the previous tutorials stated in the prerequisite, you need to initialize the distributed environment +for Colossal-AI after your config file is prepared. +We call this process `launch`. +In this tutorial, you will learn how to launch Colossal-AI on your server, be it a small one or big one. + +In Colossal-AI, we provided several launch methods to initialize the distributed backend. +In most cases, you can use `colossalai.launch` and `colossalai.get_default_parser` to pass the +parameters via command line. +If you happen to use launchers such as SLURM, OpenMPI and PyTorch launch utility, +we also provide several launching helper methods to access the rank and world size from the environment variables +set by these launchers directly for your convenience. + +In this tutorial we will cover how to launch Colossal-AI to initialize the distributed backends: +- Launch with `colossalai.launch` +- Launch with Colossal-AI CLI +- Launch with SLURM +- Launch with OpenMPI + +## Launch Distributed Environment + +In order to launch Colossal-AI, we need two types of arguments: +1. config file +2. distributed settings + +The config file is always required regardless of the launch method but distributed settings can vary. The config file +can be a path to the configuration file or a Python dictionary. The distributed settings can be passed via command line +or multi-process launchers. + +### Command Line Parser + +Before we jump to `launch`, we firstly need to understand what parameters we need for initialization. +As stated in the `Basic Concepts in Distributed Training` section of [Distributed Training](../concepts/distributed_training.md), +the important parameters are: + +1. host +2. port +3. rank +4. world_size +5. backend + +In Colossal-AI, we provided a command line parser which has added these arguments in advance. You can get this parser by calling +`colossalai.get_default_parser()`. This parser is usually used with `colossalai.launch`. + +```python +# add these lines in your train.py +import colossalai + +# get default parser +parser = colossalai.get_default_parser() + +# if you want to add your own arguments +parser.add_argument(...) + +# parse arguments +args = parser.parse_args() +``` + +Then in your terminal, you can pass in these arguments: +```shell + +python train.py --host --rank --world_size --port --backend +``` + +`backend` is optional and the default value is `nccl`. + +### Native Launch + +To initialize the distributed environment, we provided a general `colossalai.launch` API. The `colossalai.launch` function takes in the parameters +listed above and create a default process group in the communication network. This function is often used with the default +parser for convenience. + +```python +import colossalai + +# parse arguments +args = colossalai.get_default_parser().parse_args() + +# launch distributed environment +colossalai.launch(config=, + rank=args.rank, + world_size=args.world_size, + host=args.host, + port=args.port, + backend=args.backend +) + +``` + + +### Launch with Colossal-AI CLI + +To enable easy launching on both single or multi nodes, we have implemented a launcher for Colossal-AI. This launcher is +a wrapper of the torch distributed launch utility but enhanced with the capability of launching multi-node jobs easily. + +First, we need to set the launch method in our code. As this is a wrapper of the torch distributed launch utility, we will +use `colossalai.launch_from_torch`. The arguments required for distributed environment such as rank, world size, host and port are all set by the PyTorch +launcher and can be read from the environment variable directly. + +```python +import colossalai + +colossalai.launch_from_torch( + config=, +) +``` + +Next, we can easily start multiple processes with `colossalai run` in your terminal. Below is an example to run the code +on a single node with 4 GPUs. You can change the number of GPUs by `nproc_per_node` and the default port by `master_port`. + +```shell +# run on the local node with 4 GPUs (default port: 29500) +colossalai run --nproc_per_node 4 train.py + +# run on the local node with 4 GPUs with a different port +colossalai run --nproc_per_node 4 --master_port 29505 test.py +``` + +If you are in a cluster and want to launch multi-node training, the CLI can help you start processes on different nodes +with one simple command. There are two ways you can launch multi-node jobs. + +- Run with `--hosts` + +This is suitable when you only have a few nodes. Let's say I have two nodes, namely `host1` and `host2`, I can start +multi-node training with the following command. Compared to single-node training, you must specify the `master_addr` +option, which is auto-set to localhost if running on a single node only. + +:::caution + +`master_addr` cannot be localhost when running on multiple nodes, it should be the hostname or IP address of a node. + +::: + +```shell +# run on these two nodes +colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py +``` +- Run with `--hostfile` + +This method is suitable when you have a lot of nodes. The host file is a simple text file listing the available nodes. +The list of nodes is commonly provided by cluster managers such as SLURM and PBS Pro. For example, you can get the list +of nodes allocated to you via the environment variable `SLURM_NODELIST` in SLURM and `PBS_NODEFILE` in PBS Pro. +Just do `echo $SLURM_NODELIST` or `cat $PBS_NODEFILE` to check it out. If you do not have such cluster managers, you can +manually create one for your own use. + +The host file given to Colossal-AI launcher must be in the following format where each line is the host name of a node. + +```text +host1 +host2 +``` + +With the host file ready, we can launch multi-node jobs with the following commands. Just like using `--host`, you also +need to specify the `master_addr` option. Some extra options are provided for `--hostfile` as listed below: + +- `--include`: specify the hosts to include for multi-node jobs. For example, if your host file has 8 nodes, but you +happen to only want to run on 6 nodes instead, you can add `--include host1,host2,host3,...,host6` so that the job will only +be launcher on the 6 nodes. +- `--exclude`: specify the hosts to exclude for multi-node jobs. This is useful when some nodes are faulty. For example, +if host1 GPU has some problems and you do not wish to run on host1 but all other nodes, you can add `--exclude host1` so that +the job will only be launched on the remaining nodes. + +```shell +# run with a hostfile +colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 test.py + +# only include certain hosts to execute commands +# this is used to manually select nodes to run +colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 --include host1 test.py + +# exclude certain hosts to execute commands +# this can be used when certain nodes are faulty +colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 --exclude host2 test.py +``` + +### Launch with SLURM + +If you are on a system managed by the SLURM scheduler, you can also rely on the `srun` launcher to kickstart your Colossal-AI scripts. +We provided the helper function `launch_from_slurm` for compatibility with the SLURM scheduler. +`launch_from_slurm` will automatically read the rank and world size from the environment variables `SLURM_PROCID` and `SLURM_NPROCS` respectively +and use them to start the distributed backend. +Do this in your training script: + +```python +import colossalai + +colossalai.launch_from_slurm( + config=, + host=args.host, + port=args.port +) +``` + +You can initialize the distributed environment by using this command in terminal. + +```bash +srun python train.py --host --port 29500 +``` + +### Launch with OpenMPI +If you are more familiar with OpenMPI, you can use `launch_from_openmpi` instead. +`launch_from_openmpi` will automatically read the local rank, global rank and world size from the environment variables +`OMPI_COMM_WORLD_LOCAL_RANK`, `MPI_COMM_WORLD_RANK` and `OMPI_COMM_WORLD_SIZE` respectively and +use them to start the distributed backend. + +Do this in your train.py: +```python +colossalai.launch_from_openmpi( + config=, + host=args.host, + port=args.port +) +``` + +A sample command to launch multiple processes with OpenMPI would be: + +```bash +mpirun --hostfile -np python train.py --host --port 29500 +``` + +- --hostfile: use this option to specify a list of hosts on which to run +- --np: set the number of processes (GPUs) to launch in total. For example, if --np 4, 4 python processes will be initialized to run train.py. diff --git a/docs/source/en/basics/model_checkpoint.md b/docs/source/en/basics/model_checkpoint.md new file mode 100644 index 0000000000000000000000000000000000000000..09d44e7c27097a37486ef4678af07aeb1163d44e --- /dev/null +++ b/docs/source/en/basics/model_checkpoint.md @@ -0,0 +1,61 @@ +# Model Checkpoint + +Author : Guangyang Lu + +**Prerequisite:** +- [Launch Colossal-AI](./launch_colossalai.md) +- [Initialize Colossal-AI](./initialize_features.md) + +**Example Code:** +- [ColossalAI-Examples Model Checkpoint](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/utils/checkpoint) + +**This function is experiential.** + +## Introduction + +In this tutorial, you will learn how to save and load model checkpoints. + +To leverage the power of parallel strategies in Colossal-AI, modifications to models and tensors are needed, for which you cannot directly use `torch.save` or `torch.load` to save or load model checkpoints. Therefore, we have provided you with the API to achieve the same thing. + +Moreover, when loading, you are not demanded to use the same parallel strategy as saving. + +## How to use + +### Save + +There are two ways to train a model in Colossal-AI, by engine or by trainer. +**Be aware that we only save the `state_dict`.** Therefore, when loading the checkpoints, you need to define the model first. + +#### Save when using engine + +```python +from colossalai.utils import save_checkpoint +model = ... +engine, _, _, _ = colossalai.initialize(model=model, ...) +for epoch in range(num_epochs): + ... # do some training + save_checkpoint('xxx.pt', epoch, model) +``` + +#### Save when using trainer +```python +from colossalai.trainer import Trainer, hooks +model = ... +engine, _, _, _ = colossalai.initialize(model=model, ...) +trainer = Trainer(engine, ...) +hook_list = [ + hooks.SaveCheckpointHook(1, 'xxx.pt', model) + ...] + +trainer.fit(... + hook=hook_list) +``` + +### Load + +```python +from colossalai.utils import load_checkpoint +model = ... +load_checkpoint('xxx.pt', model) +... # train or test +``` diff --git a/docs/source/en/concepts/colossalai_overview.md b/docs/source/en/concepts/colossalai_overview.md new file mode 100644 index 0000000000000000000000000000000000000000..38b682d49e62dd3ee9b7b03ea34fdd8697bda286 --- /dev/null +++ b/docs/source/en/concepts/colossalai_overview.md @@ -0,0 +1,36 @@ +# Colossal-AI Overview + +Author: Shenggui Li, Siqi Mai + +## About Colossal-AI + +With the development of deep learning model size, it is important to shift to a new training paradigm. The traditional training method with no parallelism and optimization became a thing of the past and new training methods are the key to make training large-scale models efficient and cost-effective. + +Colossal-AI is designed to be a unified system to provide an integrated set of training skills and utilities to the user. You can find the common training utilities such as mixed precision training and gradient accumulation. Besides, we provide an array of parallelism including data, tensor and pipeline parallelism. We optimize tensor parallelism with different multi-dimensional distributed matrix-matrix multiplication algorithm. We also provided different pipeline parallelism methods to allow the user to scale their model across nodes efficiently. More advanced features such as offloading can be found in this tutorial documentation in detail as well. + +## General Usage + +We aim to make Colossal-AI easy to use and non-intrusive to user code. There is a simple general workflow if you want to use Colossal-AI. + +
+ +
Workflow
+
+ +1. Prepare a configuration file where specifies the features you want to use and your parameters. +2. Initialize distributed backend with `colossalai.launch` +3. Inject the training features into your training components (e.g. model, optimizer) with `colossalai.initialize`. +4. Run training and testing + +We will cover the whole workflow in the `basic tutorials` section. + +## Future Development + +The Colossal-AI system will be expanded to include more training skills, these new developments may include but are not limited to: + +1. optimization of distributed operations +2. optimization of training on heterogenous system +3. implementation of training utilities to reduce model size and speed up training while preserving model performance +4. expansion of existing parallelism methods + +We welcome ideas and contribution from the community and you can post your idea for future development in our forum. diff --git a/docs/source/en/concepts/distributed_training.md b/docs/source/en/concepts/distributed_training.md new file mode 100644 index 0000000000000000000000000000000000000000..5038714f754b8a39b859d935ecc3c5cfd4733834 --- /dev/null +++ b/docs/source/en/concepts/distributed_training.md @@ -0,0 +1,120 @@ +# Distributed Training + +Author: Shenggui Li, Siqi Mai + +## What is a distributed system? + +
+ +
Image source: Towards Data Science
+
+ +A distributed system consists of multiple software components which run on multiple machines. For example, the traditional +database runs on a single machine. As the amount of data gets incredibly large, a single machine can no longer deliver desirable +performance to the business, especially in situations such as Black Friday where network traffic can be unexpectedly high. +To handle such pressure, modern high-performance database is designed to run on multiple machines, and they work together to provide +high throughput and low latency to the user. + +One important evaluation metric for distributed system is scalability. For example, when we run an application on 4 machines, +we naturally expect that the application can run 4 times faster. However, due to communication overhead and difference in +hardware performance, it is difficult to achieve linear speedup. Thus, it is important to consider how to make the application +faster when we implement it. Algorithms of good design and system optimization can help to deliver good performance. Sometimes, +it is even possible to achieve linear and super-linear speedup. + + +## Why we need distributed training for machine learning? + +Back in 2012, [AlexNet](https://arxiv.org/abs/1404.5997) won the champion of the ImageNet competition, and it was trained +on two GTX 580 3GB GPUs. +Today, most models that appear in the top AI conferences are trained on multiple GPUs. Distributed training is definitely +a common practice when researchers and engineers develop AI models. There are several reasons behind this trend. + +1. Model size increases rapidly. [ResNet50](https://arxiv.org/abs/1512.03385) has 20 million parameters in 2015, +[BERT-Large](https://arxiv.org/abs/1810.04805) has 345 million parameters in 2018, +[GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) +has 1.5 billion parameters in 2018, and [GPT-3](https://arxiv.org/abs/2005.14165) has 175 billion parameters in 2020. +It is obvious that the model size grows exponentially with time. The current largest model has exceeded more than 1000 +billion parameters. Super large models generally deliver more superior performance compared to their smaller counterparts. +
+ +
Image source: HuggingFace
+
+ + +2. Dataset size increases rapidly. For most machine learning developers, MNIST and CIFAR10 datasets are often the first few +datasets on which they train their models. However, these datasets are very small compared to well-known ImageNet datasets. +Google even has its own (unpublished) JFT-300M dataset which has around 300 million images, and this is close to 300 times +larger than the ImageNet-1k dataset. + + +3. Computing power gets stronger. With the advancement in the semiconductor industry, graphics cards become more and more +powerful. Due to its larger number of cores, GPU is the most common compute platform for deep learning. +From K10 GPU in 2012 to A100 GPU in 2020, the computing power has increased several hundred times. This allows us to performance +compute-intensive tasks faster and deep learning is exactly such a task. + +Nowadays, the model can be too large to fit into a single GPU, and the dataset can be large enough to train for a hundred +days on a single GPU. Only by training our models on multiple GPUs with different parallelization techniques, we are able +to speed up the training process and obtain results in a reasonable amount of time. + + +## Basic Concepts in Distributed Training + +Distributed training requires multiple machines/GPUs. During training, there will be communication among these devices. +To understand distributed training better, there are several important terms to be made clear. + +- host: host is the main device in the communication network. It is often required as an argument when initializing the +distributed environment. +- port: port here mainly refers to master port on the host for communication. +- rank: the unique ID given to a device in the network. +- world size: the number of devices in the network. +- process group: a process group is a communication network which include a subset of the devices. There is always a default +process group which contains all the devices. A subset devices can form a process group so that they only communicate among +the devices within the group. + +
+ +
A distributed system example
+
+ +To illustrate these concepts, let's assume we have 2 machines (also called nodes), and each machine has 4 GPUs. When we +initialize distributed environment over these two machines, we essentially launch 8 processes (4 processes on each machine) +and each process is bound to a GPU. + +Before initializing the distributed environment, we need to specify the host (master address) and port (master port). In +this example, we can let host be node 0 and port be a number such as 29500. All the 8 processes will then look for the +address and port and connect to one another. +The default process group will then be created. The default process group has a world size of 8 and details are as follows: + +| process ID | rank | Node index | GPU index | +| ---------- | ---- | ---------- | --------- | +| 0 | 0 | 0 | 0 | +| 1 | 1 | 0 | 1 | +| 2 | 2 | 0 | 2 | +| 3 | 3 | 0 | 3 | +| 4 | 4 | 1 | 0 | +| 5 | 5 | 1 | 1 | +| 6 | 6 | 1 | 2 | +| 7 | 7 | 1 | 3 | + + +We can also create a new process group. This new process group can contain any subset of the processes. +For example, we can create one containing only even-number processes, and the details of this new group will be: + +| process ID | rank | Node index | GPU index | +| ---------- | ---- | ---------- | --------- | +| 0 | 0 | 0 | 0 | +| 2 | 1 | 0 | 2 | +| 4 | 2 | 1 | 0 | +| 6 | 3 | 1 | 2 | + +**Please note that rank is relative to the process group and one process can have a different rank in different process +groups. The max rank is always `world size of the process group - 1`.** + +In the process group, the processes can communicate in two ways: +1. peer-to-peer: one process send data to another process +2. collective: a group of process perform operations such as scatter, gather, all-reduce, broadcast together. + +
+ +
Collective communication, source: PyTorch distributed tutorial
+
diff --git a/docs/source/en/concepts/paradigms_of_parallelism.md b/docs/source/en/concepts/paradigms_of_parallelism.md new file mode 100644 index 0000000000000000000000000000000000000000..1a5dab7a76f79caf28b534ff2f38dcfd4015591c --- /dev/null +++ b/docs/source/en/concepts/paradigms_of_parallelism.md @@ -0,0 +1,124 @@ +# Paradigms of Parallelism + +Author: Shenggui Li, Siqi Mai + +## Introduction + +With the development of deep learning, there is an increasing demand for parallel training. This is because that model +and datasets are getting larger and larger and training time becomes a nightmare if we stick to single-GPU training. In +this section, we will provide a brief overview of existing methods to parallelize training. If you wish to add on to this +post, you may create a discussion in the [GitHub forum](https://github.com/hpcaitech/ColossalAI/discussions). + +## Data Parallel + +Data parallel is the most common form of parallelism due to its simplicity. In data parallel training, the dataset is split +into several shards, each shard is allocated to a device. This is equivalent to parallelize the training process along the +batch dimension. Each device will hold a full copy of the model replica and trains on the dataset shard allocated. After +back-propagation, the gradients of the model will be all-reduced so that the model parameters on different devices can stay +synchronized. + +
+ +
Data parallel illustration
+
+ +## Model Parallel + +In data parallel training, one prominent feature is that each GPU holds a copy of the whole model weights. This brings +redundancy issue. Another paradigm of parallelism is model parallelism, where model is split and distributed over an array +of devices. There are generally two types of parallelism: tensor parallelism and pipeline parallelism. Tensor parallelism is +to parallelize computation within an operation such as matrix-matrix multiplication. Pipeline parallelism is to parallelize +computation between layers. Thus, from another point of view, tensor parallelism can be seen as intra-layer parallelism and +pipeline parallelism can be seen as inter-layer parallelism. + +### Tensor Parallel + +Tensor parallel training is to split a tensor into `N` chunks along a specific dimension and each device only holds `1/N` +of the whole tensor while not affecting the correctness of the computation graph. This requires additional communication +to make sure that the result is correct. + +Taking a general matrix multiplication as an example, let's say we have C = AB. We can split B along the column dimension +into `[B0 B1 B2 ... Bn]` and each device holds a column. We then multiply `A` with each column in `B` on each device, we +will get `[AB0 AB1 AB2 ... ABn]`. At this moment, each device still holds partial results, e.g. device rank 0 holds `AB0`. +To make sure the result is correct, we need to all-gather the partial result and concatenate the tensor along the column +dimension. In this way, we are able to distribute the tensor over devices while making sure the computation flow remains +correct. + +
+ +
Tensor parallel illustration
+
+ +In Colossal-AI, we provide an array of tensor parallelism methods, namely 1D, 2D, 2.5D and 3D tensor parallelism. We will +talk about them in detail in `advanced tutorials`. + + +Related paper: +- [GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding](https://arxiv.org/abs/2006.16668) +- [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) +- [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343) +- [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500) +- [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450) + +### Pipeline Parallel + +Pipeline parallelism is generally easy to understand. If you recall your computer architecture course, this indeed exists +in the CPU design. + +
+ +
Pipeline parallel illustration
+
+ +The core idea of pipeline parallelism is that the model is split by layer into several chunks, each chunk is +given to a device. During the forward pass, each device passes the intermediate activation to the next stage. During the backward pass, +each device passes the gradient of the input tensor back to the previous pipeline stage. This allows devices to compute simultaneously, +and increases the training throughput. One drawback of pipeline parallel training is that there will be some bubble time where +some devices are engaged in computation, leading to waste of computational resources. + +
+ +
Source: GPipe
+
+ +Related paper: +- [PipeDream: Fast and Efficient Pipeline Parallel DNN Training](https://arxiv.org/abs/1806.03377) +- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965) +- [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) +- [Chimera: Efficiently Training Large-Scale Neural Networks with Bidirectional Pipelines](https://arxiv.org/abs/2107.06925) + + +## Optimizer-Level Parallel + +Another paradigm works at the optimizer level, and the current most famous method of this paradigm is ZeRO which stands +for [zero redundancy optimizer](https://arxiv.org/abs/1910.02054). ZeRO works at three levels to remove memory redundancy +(fp16 training is required for ZeRO): + +- Level 1: The optimizer states are partitioned across the processes +- Level 2: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process +only stores the gradients corresponding to its partition of the optimizer states. +- Level 3: The 16-bit model parameters are partitioned across the processes + +Related paper: +- [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054) + + +## Parallelism on Heterogeneous System + +The methods mentioned above generally require a large number of GPU to train a large model. However, it is often neglected +that CPU has a much larger memory compared to GPU. On a typical server, CPU can easily have several hundred GB RAM while each GPU +typically only has 16 or 32 GB RAM. This prompts the community to think why CPU memory is not utilized for distributed training. + +Recent advances rely on CPU and even NVMe disk to train large models. The main idea is to offload tensors back to CPU memory +or NVMe disk when they are not used. By using the heterogeneous system architecture, it is possible to accommodate a huge +model on a single machine. + +
+ +
Heterogenous system illustration
+
+ +Related paper: +- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840) +- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857) +- [PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management](https://arxiv.org/abs/2108.05818) diff --git a/docs/source/en/features/1D_tensor_parallel.md b/docs/source/en/features/1D_tensor_parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..7577e50400e91c5e73224590738aa10a82f2ac7e --- /dev/null +++ b/docs/source/en/features/1D_tensor_parallel.md @@ -0,0 +1,118 @@ +# 1D Tensor Parallelism + +Author: Zhengda Bian, Yongbin Li + +**Prerequisite** +- [Define Your Configuration](../basics/define_your_config.md) +- [Configure Parallelization](../basics/configure_parallelization.md) + +**Example Code** +- [ColossalAI-Examples 1D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_1d.py) + +**Related Paper** +- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf) + +## Introduction + +Tensor parallelism partitions model weights across multiple devices in order to reduce memory load. +An efficient 1D tensor parallelism implementation was introduced by [Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf). + +Let's take a linear layer as an example, which consists of a GEMM $Y = XA$. Given 2 processors, we split the columns of $A$ into $[A_1 ~ A_2]$, and calculate $Y_i = XA_i$ on each processor, which then forms $[Y_1 ~ Y_2] = [XA_1 ~ XA_2]$. This is called a column-parallel fashion. + +When a second linear layer $Z=YB$ follows the column-parallel one, we split $B$ into +```math +\left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right] +``` +which is called a row-parallel fashion. +To calculate +```math +Z = [Y_1 ~ Y_2] \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right] +``` +we first calculate $Y_iB_i$ on each processor, then use an all-reduce to aggregate the results as $Z=Y_1B_1+Y_2B_2$. + +We also need to note that in the backward pass, the column-parallel linear layer needs to aggregate the gradients of the input tensor $X$, because on each processor $i$ we only have $\dot{X_i}=\dot{Y_i}A_i^T$. +Thus, we apply an all-reduce across the processors to get $\dot{X}=\dot{Y}A^T=\dot{Y_1}A_1^T+\dot{Y_2}A_2^T$. + +## Efficiency +Given $P$ processors, we present the theoretical computation and memory cost, as well as the communication cost based on the ring algorithm in both the forward and backward pass of 1D tensor parallelism. + +| Computation | Memory (parameters) | Memory (activations) | Communication (bandwidth) | Communication (latency) | +| :-: | :-: | :-: | :-: | :-: | +| $O(1/P)$ | $O(1/P)$ | $O(1)$ | $O(2(P-1)/P)$ | $O(2(P-1))$ | + +## Usage + +To enable 1D tensor parallelism for our model, e.g. on 2 GPUs, we need to configure the parallelism setting as below. +```python +CONFIG = dict(parallel=dict( + data=1, + pipeline=1, + tensor=dict(size=2, mode='1d'), +)) +``` +Then Colossal-AI will automatically apply 1D parallelism to all the layers from `colossalai.nn`. + +Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below. +```python +import colossalai +import colossalai.nn as col_nn +import torch +from colossalai.utils import print_rank_0 + +class MLP(torch.nn.Module): + def __init__(self, dim: int = 256): + super().__init__() + intermediate_dim = dim * 4 + self.dense_1 = col_nn.Linear(dim, intermediate_dim) + print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.transpose(0, 1).shape}') + self.activation = torch.nn.GELU() + self.dense_2 = col_nn.Linear(intermediate_dim, dim) + print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.transpose(0, 1).shape}') + self.dropout = col_nn.Dropout(0.1) + + def forward(self, x): + x = self.dense_1(x) + print_rank_0(f'Output of the first linear layer: {x.shape}') + x = self.activation(x) + x = self.dense_2(x) + print_rank_0(f'Output of the second linear layer: {x.shape}') + x = self.dropout(x) + return x +``` + +Launch Colossal-AI on 2 GPUs and build the model. + +```python +parser = colossalai.get_default_parser() +colossalai.launch(config=CONFIG, + rank=args.rank, + world_size=args.world_size, + local_rank=args.local_rank, + host=args.host, + port=args.port) + +m = MLP() +``` +We will see the shapes of partitioned parameters(e.g. weights) in the MLP model. +```shell +Weight of the first linear layer: torch.Size([256, 512]) +Weight of the second linear layer: torch.Size([512, 256]) +``` +The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the column-parallel partitioning, it becomes `[256, 512]`. +Similarly, the second row-parallel layer partitions the weight `[1024, 256]` into `[512, 256]`. + +We can run the model with some random inputs. +```python +from colossalai.utils import get_current_device + +x = torch.randn((16, 256), device=get_current_device()) +torch.distributed.broadcast(x, src=0) # synchronize input + +x = m(x) +``` +Then we can see the shapes of activation results. +```shell +Output of the first linear layer: torch.Size([16, 512]) +Output of the second linear layer: torch.Size([16, 256]) +``` +The output of the first linear layer is split into 2 partitions (each has the shape `[16, 512]`), while the second layer has identical outputs across the GPUs. diff --git a/docs/source/en/features/2D_tensor_parallel.md b/docs/source/en/features/2D_tensor_parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..7b6c10766099f6fec50b656138fa2a7fd0cdd132 --- /dev/null +++ b/docs/source/en/features/2D_tensor_parallel.md @@ -0,0 +1,142 @@ +# 2D Tensor Parallelism + +Author: Zhengda Bian, Yongbin Li + +**Prerequisite** +- [Define Your Configuration](../basics/define_your_config.md) +- [Configure Parallelization](../basics/configure_parallelization.md) +- [1D Tensor Parallelism](./1D_tensor_parallel.md) + +**Example Code** +- [ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_2d.py) + +**Related Paper** +- [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/pdf/2104.05343.pdf) + +## Introduction + +1D tensor parallelism does not partition activations, which can also consume a great amount of memory in terms of large-scale models. +To evenly distribute the computation and memory load, [an efficient 2D tensor parallelism algorithm](https://arxiv.org/pdf/2104.05343.pdf) was introduced based on SUMMA (Scalable Universal Matrix Multiplication Algorithm). + +Let's still take a linear layer $Y = XA$ as an example. +Given $P=q\times q$ processors (necessary condition), e.g. $q=2$, we split both the input $X$ and weight $A$ into + +$$ +\left[\begin{matrix} X_{10} & X_{11} \\ X_{00} & X_{01} \end{matrix} \right] +\text{~and~} +\left[\begin{matrix} A_{10} & A_{11} \\ A_{00} & A_{01} \end{matrix} \right]. +$$ + +The calculation includes $q$ steps. When $t=1$, $X_{i0}$ is broadcasted in its row, and $A_{0j}$ is broadcasted in its column. So, we have + +$$ +\left[\begin{matrix} X_{10},A_{00} & X_{10},A_{01} \\ X_{00},A_{00} & X_{00},A_{01} \end{matrix} \right]. +$$ + +Then we multiply $X_{i0}$ and $A_{0j}$ on each processor $(i, j)$ as + +$$ +\left[\begin{matrix} X_{10}A_{00} & X_{10}A_{01} \\ X_{00}A_{00} & X_{00}A_{01} \end{matrix} \right] (1). +$$ + +Similarly, when $t=2$, $X_{i1}$ is broadcasted in its row, $A_{1j}$ is broadcasted in its column, and we multiply them as + +$$ +\left[\begin{matrix} X_{11}A_{10} & X_{11}A_{11} \\ X_{01}A_{10} & X_{01}A_{11} \end{matrix} \right] (2). +$$ + +By adding $(1)$ and $(2)$ up, we have + +$$ +Y = XA = \left[\begin{matrix} X_{10}A_{00}+X_{11}A_{10} & X_{10}A_{01}+X_{11}A_{11} \\ X_{00}A_{00}+X_{01}A_{10} & X_{00}A_{01}+X_{01}A_{11} \end{matrix} \right]. +$$ + +## Efficiency +Given $P=q\times q$ processors, we present the theoretical computation and memory cost, as well as the communication cost based on the ring algorithm in both the forward and backward pass of 2D tensor parallelism. + +| Computation | Memory (parameters) | Memory (activations) | Communication (bandwidth) | Communication (latency) | +| :-: | :-: | :-: | :-: | :-: | +| $O(1/q^2)$ | $O(1/q^2)$ | $O(1/q^2)$ | $O(6(q-1)/q)$ | $O(6(q-1))$ | + +## Usage + +To enable 2D tensor parallelism for our model, e.g. on 4 GPUs, we need to configure the parallelism setting as below. +```python +CONFIG = dict(parallel=dict( + data=1, + pipeline=1, + tensor=dict(size=4, mode='2d'), +)) +``` +Then Colossal-AI will automatically apply 2D parallelism to all the layers from `colossalai.nn`. + +Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below. +```python +import colossalai +import colossalai.nn as col_nn +import torch +from colossalai.utils import print_rank_0 + +class MLP(torch.nn.Module): + def __init__(self, dim: int = 256): + super().__init__() + intermediate_dim = dim * 4 + self.dense_1 = col_nn.Linear(dim, intermediate_dim) + print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') + self.activation = torch.nn.GELU() + self.dense_2 = col_nn.Linear(intermediate_dim, dim) + print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') + self.dropout = col_nn.Dropout(0.1) + + def forward(self, x): + x = self.dense_1(x) + print_rank_0(f'Output of the first linear layer: {x.shape}') + x = self.activation(x) + x = self.dense_2(x) + print_rank_0(f'Output of the second linear layer: {x.shape}') + x = self.dropout(x) + return x +``` +Launch Colossal-AI on 4 GPUs and build the model +```python +parser = colossalai.get_default_parser() +colossalai.launch(config=CONFIG, + rank=args.rank, + world_size=args.world_size, + local_rank=args.local_rank, + host=args.host, + port=args.port) + +m = MLP() +``` +We will see the shapes of partitioned parameters(e.g. weights) in the MLP model. +```shell +Weight of the first linear layer: torch.Size([128, 512]) +Weight of the second linear layer: torch.Size([512, 128]) +``` +The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 2D parallelism, it becomes `[128, 512]` on each GPU. +Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 128]`. + +We can run the model with some random inputs. +```python +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device + +x = torch.randn((16, 256), device=get_current_device()) +# partition input +torch.distributed.broadcast(x, src=0) +x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)] +x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)] +print_rank_0(f'Input: {x.shape}') + +x = m(x) +``` +Then we can see the shapes of activation results. +```shell +Input: torch.Size([8, 128]) +Output of the first linear layer: torch.Size([8, 512]) +Output of the second linear layer: torch.Size([8, 128]) +``` +The activation tensors in 2D parallelism are all split in both row and column. +E.g. the output of the first linear layer has the shape `[8, 512]`, while the second layer has the output of `[8, 128]`. diff --git a/docs/source/en/features/2p5D_tensor_parallel.md b/docs/source/en/features/2p5D_tensor_parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..6076562e6dca51a7f41c080d9e115d8b3f2a3415 --- /dev/null +++ b/docs/source/en/features/2p5D_tensor_parallel.md @@ -0,0 +1,142 @@ +# 2.5D Tensor Parallelism + +Author: Zhengda Bian, Yongbin Li + +**Prerequisite** +- [Define Your Configuration](../basics/define_your_config.md) +- [Configure Parallelization](../basics/configure_parallelization.md) +- [1D Tensor Parallelism](./1D_tensor_parallel.md) +- [2D Tensor Parallelism](./2D_tensor_parallel.md) + +**Example Code** +- [ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_2p5d.py) + +**Related Paper** +- [2.5-dimensional distributed model training](https://arxiv.org/pdf/2105.14500.pdf) + +## Introduction + +Compared with 1D tensor parallelism, 2D parallelism reduces the memory cost, but may introduce more communication. +Therefore, a [2.5D tensor parallelism algorithm](https://arxiv.org/pdf/2105.14500.pdf) was proposed based on 2.5D SUMMA to reduce communication by using more devices. + +Let's still take a linear layer $Y = XA$ as an example. +Given $P=q \times q \times d$ processors (necessary condition), e.g. $q=d=2$, we split the input $X$ into $d\times q$ rows and $q$ columns as + +$$ +\left[\begin{matrix} X_{30} & X_{31} \\ X_{20} & X_{21} \\ X_{10} & X_{11} \\ X_{00} & X_{01}\end{matrix} \right], +$$ +which can be reshaped into $d$ layers as + +$$ +\left[\begin{matrix} X_{10} & X_{11} \\ X_{00} & X_{01} \end{matrix} \right] \text{~and~}\left[\begin{matrix} X_{30} & X_{31} \\ X_{20} & X_{21} \end{matrix} \right]. +$$ + +Also, the weight $A$ is split into + +$$ +\left[\begin{matrix} A_{10} & A_{11} \\ A_{00} & A_{01} \end{matrix} \right]. +$$ + +For each layer of $X$, we use the SUMMA algorithm to multiply $X$ and $A$. +Then, we have the output + +$$ +\left[\begin{matrix} Y_{10}=X_{10}A_{00}+X_{11}A_{10} & Y_{11}=X_{10}A_{01}+X_{11}A_{11} \\ Y_{00}=X_{00}A_{00}+X_{01}A_{10} & Y_{01}=X_{00}A_{01}+X_{01}A_{11} \end{matrix} \right] +\text{~and~} +$$ +$$ +\left[\begin{matrix} Y_{30}=X_{30}A_{00}+X_{31}A_{10} & Y_{31}=X_{30}A_{01}+X_{31}A_{11} \\ Y_{20}=X_{20}A_{00}+X_{21}A_{10} & Y_{21}=X_{20}A_{01}+X_{21}A_{11} \end{matrix} \right]. +$$ + +## Efficiency +Given $P=q \times q \times d$ processors, we present the theoretical computation and memory cost, as well as the communication cost based on the ring algorithm in both the forward and backward pass of 2.5D tensor parallelism. + +| Computation | Memory (parameters) | Memory (activations) | Communication (bandwidth) | Communication (latency) | +| :-: | :-: | :-: | :-: | :-: | +| $O(1/dq^2)$ | $O(1/q^2)$ | $O(1/dq^2)$ | $\small O(3(q-1)(d+1)/dq)$ | $O(6(q-1))$ | + +## Usage + +To enable 2.5D tensor parallelism for our model, e.g. on 8 GPUs, we need to configure the parallelism setting as below. +```python +CONFIG = dict(parallel=dict( + data=1, + pipeline=1, + tensor=dict(size=8, mode='2.5d', depth=2), +)) + +``` +Then Colossal-AI will automatically apply 2.5D parallelism to all the layers from `colossalai.nn`. + +Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below. +```python +import colossalai +import colossalai.nn as col_nn +import torch +from colossalai.utils import print_rank_0 + +class MLP(torch.nn.Module): + def __init__(self, dim: int = 256): + super().__init__() + intermediate_dim = dim * 4 + self.dense_1 = col_nn.Linear(dim, intermediate_dim) + print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') + self.activation = torch.nn.GELU() + self.dense_2 = col_nn.Linear(intermediate_dim, dim) + print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') + self.dropout = col_nn.Dropout(0.1) + + def forward(self, x): + x = self.dense_1(x) + print_rank_0(f'Output of the first linear layer: {x.shape}') + x = self.activation(x) + x = self.dense_2(x) + print_rank_0(f'Output of the second linear layer: {x.shape}') + x = self.dropout(x) + return x +``` +Launch Colossal-AI on 8 GPUs and build the model +```python +parser = colossalai.get_default_parser() +colossalai.launch(config=CONFIG, + rank=args.rank, + world_size=args.world_size, + local_rank=args.local_rank, + host=args.host, + port=args.port) + +m = MLP() +``` +We will see the shapes of partitioned parameters(e.g. weights) in the MLP model. +```shell +Weight of the first linear layer: torch.Size([128, 512]) +Weight of the second linear layer: torch.Size([512, 128]) +``` +The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 2.5D parallelism, it becomes `[128, 512]` on each GPU. +Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 128]`. + +We can run the model with some random inputs. +```python +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device + +x = torch.randn((16, 256), device=get_current_device()) +# partition input +torch.distributed.broadcast(x, src=0) +x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)] +x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)] +x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)] +print_rank_0(f'Input: {x.shape}') + +x = m(x) +``` +Then we can see the shapes of activation results. +```shell +Input: torch.Size([4, 128]) +Output of the first linear layer: torch.Size([4, 512]) +Output of the second linear layer: torch.Size([4, 128]) +``` +The activation tensors in 2.5D parallelism are all split by $d \times q$ in the row and $q$ in the column. +E.g. the output of the first linear layer has the shape `[4, 512]`), while the second layer has the output of `[4, 128]`. +Note, 2.5D parallelism use the same partition method as 2D parallelism for weights, where the difference is the partition of input. diff --git a/docs/source/en/features/3D_tensor_parallel.md b/docs/source/en/features/3D_tensor_parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..1207376335cea4b7865b210f8b9ea09703e09a45 --- /dev/null +++ b/docs/source/en/features/3D_tensor_parallel.md @@ -0,0 +1,151 @@ +# 3D Tensor Parallelism + +Author: Zhengda Bian, Yongbin Li + +**Prerequisite** +- [Define Your Configuration](../basics/define_your_config.md) +- [Configure Parallelization](../basics/configure_parallelization.md) +- [1D Tensor Parallelism](./1D_tensor_parallel.md) +- [2D Tensor Parallelism](./2D_tensor_parallel.md) + +**Example Code** +- [ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_3d.py) + +**Related Paper** +- [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/pdf/2105.14450.pdf) + +## Introduction + +The [3D tensor parallelism](https://arxiv.org/pdf/2105.14450.pdf) is an approach to parallelize the computation of neural models, hoping to obtain the optimal communication cost. + +Let's still take a linear layer $Y = XA$ as an example. +Given $P=q \times q \times q$ processors (necessary condition), e.g. $q=2$, we split the input $X$ and weight $A$ into + +$$ +\left[\begin{matrix} + X_{000} & X_{001} \\ + X_{010} & X_{011} \\ + X_{100} & X_{101} \\ + X_{110} & X_{111} \end{matrix} +\right] +\text{~and~} +\left[\begin{matrix} + A_{000} & A_{001} & A_{010} & A_{011} \\ + A_{100} & A_{101} & A_{110} & A_{111} \end{matrix} +\right] +\text{~respectively,}$$ +where each $X_{ijl}$ and $A_{lji}$ are stored at processor $(i,j,l)$, as shown in the figure below. + +
+ + + + +
+ +Then we all-gather $X_{ijl}$ across $(i, 0...q,l)$, as well as $A_{lji}$ across $(0...q, j, l)$. +So, we have $X_{il}$ and $A_{lj}$ on each processor $(i,j,l)$ to get $X_{il}A_{lj}$. +Finally, we reduce-scatter the results across $(i, j, 0...q)$ to get $Y_{ijl}$, which forms +$$ +Y= +\left[\begin{matrix} + Y_{000} & Y_{001} \\ + Y_{010} & Y_{011} \\ + Y_{100} & Y_{101} \\ + Y_{110} & Y_{111} \end{matrix} +\right]. +$$ + +We also need to note that in the backward pass, we need to all-gather the gradient $\dot{Y_{ijl}}$, and then reduce-scatter the gradient $\dot{X_{il}}=\dot{Y_{ij}}A_{lj}^T$ and $\dot{A_{lj}}=X_{il}^T\dot{Y_{ij}}$. + +## Efficiency +Given $P=q \times q \times q$ processors, we present the theoretical computation and memory cost, as well as the communication cost based on the ring algorithm in both the forward and backward pass of 3D tensor parallelism. + +| Computation | Memory (parameters) | Memory (activations) | Communication (bandwidth) | Communication (latency) | +| :-: | :-: | :-: | :-: | :-: | +| $O(1/q^3)$ | $O(1/q^3)$ | $O(1/q^3)$ | $O(6(q-1)/q^3)$ | $O(6(q-1))$ | + +## Usage + +To enable 3D tensor parallelism for our model, e.g. on 8 GPUs, we need to configure the parallism setting as below. +```python +CONFIG = dict(parallel=dict( + data=1, + pipeline=1, + tensor=dict(size=8, mode='3d'), +)) +``` +Then Colossal-AI will automatically apply 3D parallelism to all the layers from `colossalai.nn`. + +Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below. +```python +import colossalai +import colossalai.nn as col_nn +import torch +from colossalai.utils import print_rank_0 + +class MLP(torch.nn.Module): + def __init__(self, dim: int = 256): + super().__init__() + intermediate_dim = dim * 4 + self.dense_1 = col_nn.Linear(dim, intermediate_dim) + print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') + self.activation = torch.nn.GELU() + self.dense_2 = col_nn.Linear(intermediate_dim, dim) + print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') + self.dropout = col_nn.Dropout(0.1) + + def forward(self, x): + x = self.dense_1(x) + print_rank_0(f'Output of the first linear layer: {x.shape}') + x = self.activation(x) + x = self.dense_2(x) + print_rank_0(f'Output of the second linear layer: {x.shape}') + x = self.dropout(x) + return x +``` +Launch Colossal-AI on 8 GPUs and build the model +```python +parser = colossalai.get_default_parser() +colossalai.launch(config=CONFIG, + rank=args.rank, + world_size=args.world_size, + local_rank=args.local_rank, + host=args.host, + port=args.port) + +m = MLP() +``` +We will see the shapes of partitioned parameters(e.g. weights) in the MLP model. +```shell +Weight of the first linear layer: torch.Size([128, 256]) +Weight of the second linear layer: torch.Size([512, 64]) +``` +The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 3D parallelism, it becomes `[128, 256]` on each GPU. +Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 64]`. + +We can run the model with some random inputs. +```python +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device + +x = torch.randn((16, 256), device=get_current_device()) +# partition input +torch.distributed.broadcast(x, src=0) +x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)] +x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)] +x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)] +print_rank_0(f'Input: {x.shape}') + +x = m(x) +``` +Then we can see the shapes of activation results. +```shell +Input: torch.Size([4, 128]) +Output of the first linear layer: torch.Size([4, 512]) +Output of the second linear layer: torch.Size([4, 128]) +``` +The activation tensors in 3D parallelism are all split by $q^2$ in the row and $q$ in the column. +E.g. the output of the first linear layer has the shape `[4, 512]`), while the second layer has the output of `[4, 128]`. +Note, although the results of 3D parallelism have the same shape as that of 2.5D parallelism for weights here, the content of each partition is different. diff --git a/docs/source/en/features/gradient_accumulation.md b/docs/source/en/features/gradient_accumulation.md new file mode 100644 index 0000000000000000000000000000000000000000..ecc209fbac8d043d5330c6f971a5cd52efe55b38 --- /dev/null +++ b/docs/source/en/features/gradient_accumulation.md @@ -0,0 +1,45 @@ +# Gradient Accumulation + +Author: Shenggui Li, Yongbin Li + +**Prerequisite** +- [Define Your Configuration](../basics/define_your_config.md) +- [Use Engine and Trainer in Training](../basics/engine_trainer.md) + +**Example Code** +- [ColossalAI-Examples Gradient Accumulation](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_accumulation) + +## Introduction + +Gradient accumulation is a common way to enlarge your batch size for training. +When training large-scale models, memory can easily become the bottleneck and the batch size can be very small, (e.g. 2), +leading to unsatisfactory convergence. Gradient accumulation works by adding up the gradients calculated in multiple iterations, +and only update the parameters in the preset iteration. + +## Usage + +It is simple to use gradient accumulation in Colossal-AI. Just add this following configuration into your config file. +The integer represents the number of iterations to accumulate gradients. + +```python +gradient_accumulation = +``` + +## Hands-on Practice + +We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_accumulation) +to demonstrate gradient accumulation. In this example, we set the gradient accumulation size to be 4. You can run the script using this command: + +```shell +python -m torch.distributed.launch --nproc_per_node 1 --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py +``` + +You will see output similar to the text below. This shows gradient is indeed accumulated as the parameter is not updated +in the first 3 steps, but only updated in the last step. + +```text +iteration 0, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) +iteration 1, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) +iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) +iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=) +``` diff --git a/docs/source/en/features/gradient_clipping.md b/docs/source/en/features/gradient_clipping.md new file mode 100644 index 0000000000000000000000000000000000000000..f606dde6c393e56269f4ad97c402529dc52569f3 --- /dev/null +++ b/docs/source/en/features/gradient_clipping.md @@ -0,0 +1,62 @@ +# Gradient Clipping + +Author: Boxiang Wang, Haichen Huang, Yongbin Li + +**Prerequisite** +- [Define Your Configuration](../basics/define_your_config.md) +- [Use Engine and Trainer in Training](../basics/engine_trainer.md) + +**Example Code** +- [ColossalAI-Examples Gradient Clipping](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_clipping) + +**Related Paper** +- [On the difficulty of training Recurrent Neural Networks](https://arxiv.org/abs/1211.5063) + +## Introduction + +In order to speed up training process and seek global optimum for better performance, more and more learning +rate schedulers have been proposed. People turn to control learning rate to adjust descent pace during training, +which makes gradient vector better to be uniformed in every step. In that case, the descent pace can be +controlled as expected. As a result, gradient clipping, a technique which can normalize the gradient vector +to circumscribe it in a uniformed length, becomes indispensable for those who desire their better +performance of their models. + +You do not have to worry about implementing gradient clipping when using Colossal-AI, we support gradient +clipping in a powerful and convenient way. All you need is just an additional command in your configuration +file. + +## Why you should use gradient clipping provided by Colossal-AI + +The reason of why we do not recommend users to write gradient clipping by themselves is that naive gradient clipping +may fail when applying tensor parallelism, pipeline parallelism or MoE. + +According to the illustration below, each GPU only owns a portion of parameters of the weight in a linear layer. +To get correct norm of gradient vector of the weight of the linear layer, the norm of every gradient vector in each GPU +should be summed together. +More complicated thing is that the distribution of bias is different from the distribution of the weight. +The communication group is different in the sum operation. + +(PS: This situation is an old version of 2D parallelism, the implementation in the code is not the same. +But it is a good example about the difficulty to unify all communication in gradient clipping.) + +
+ +
Layout of parameters
+
+ +Do not worry about it, since Colossal-AI have handled it for you. + +### Usage +To use gradient clipping, you can just simply add gradient clipping norm in your configuration file. +```python +clip_grad_norm = 1.0 +``` + +### Hands-On Practice + +We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_clipping) +to demonstrate gradient clipping. In this example, we set the gradient clipping vector norm to be 1.0. You can run the script using this command: + +```shell +python -m torch.distributed.launch --nproc_per_node 1 --master_addr localhost --master_port 29500 train_with_engine.py +``` diff --git a/docs/source/en/features/gradient_handler.md b/docs/source/en/features/gradient_handler.md new file mode 100644 index 0000000000000000000000000000000000000000..757016fcb53a5c16ac02810d88eee447fb659b2b --- /dev/null +++ b/docs/source/en/features/gradient_handler.md @@ -0,0 +1,63 @@ +# Gradient Handler + +Author: Shenggui Li, Yongbin Li + +**Prerequisite** +- [Define Your Configuration](../basics/define_your_config.md) +- [Use Engine and Trainer in Training](../basics/engine_trainer.md) + +**Example Code** +- [ColossalAI-Examples Gradient Handler](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_handler) + +## Introduction + +In distributed training, gradient synchronization is required at the end of each iteration. This is important because we +need to make sure the parameters are updated with the same gradients in different machines so that the resulting parameters +are the same. This is often seen in data parallel as the model is replicated across data parallel ranks. + +In Colossal-AI, we provide an interface for users to customize how they want to handle the synchronization. This brings +flexibility in cases such as implementing a new parallelism method. + +When gradient handlers are used, PyTorch `DistributedDataParallel` will not be used as it will synchronize automatically. + +## Customize Your Gradient Handlers + +To implement a customized gradient handler, you need to follow these steps. +1. inherit `BaseGradientHandler` in Colossal-AI. +2. register the gradient handler into the `GRADIENT_HANDLER`. +3. implement `handle_gradient` method. + +```python +from colossalai.registry import GRADIENT_HANDLER +from colossalai.engine.gradient_handler import BaseGradientHandler + + +@GRADIENT_HANDLER.register_module +class MyGradientHandler(BaseGradientHandler): + + def handle_gradient(self): + do_something() + + +``` + + +## Usage + +To use a gradient handler, you need to specify your gradient handler in the config file. The gradient handler +will be automatically built and attached to the engine. + +```python +gradient_handler = [dict(type='MyGradientHandler')] +``` + + +### Hands-On Practice + +We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_handler) +to demonstrate the use of gradient handler. In this example, we used `DataParallelGradientHandler` instead of PyTorch +`DistributedDataParallel` for data parallel training. + +```shell +python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py +``` diff --git a/docs/source/en/features/mixed_precision_training.md b/docs/source/en/features/mixed_precision_training.md new file mode 100644 index 0000000000000000000000000000000000000000..11aa5235301a36f4fc75695e06fe52f8937386da --- /dev/null +++ b/docs/source/en/features/mixed_precision_training.md @@ -0,0 +1,367 @@ +# Auto Mixed Precision Training + +Author: Chuanrui Wang, Shenggui Li, Yongbin Li + +**Prerequisite** +- [Define Your Configuration](../basics/define_your_config.md) +- [Use Engine and Trainer in Training](../basics/engine_trainer.md) + +**Example Code** +- [ColossalAI-Examples AMP](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/amp) + +**Related Paper** +- [Accelerating Scientific Computations with Mixed Precision Algorithms](https://arxiv.org/abs/0808.2794) + + +## Introduction + +AMP stands for automatic mixed precision training. +In Colossal-AI, we have incorporated different implementations of mixed precision training: + +1. torch.cuda.amp +2. apex.amp +3. naive amp + + +| Colossal-AI | support tensor parallel | support pipeline parallel | fp16 extent | +| ----------- | ----------------------- | ------------------------- | ----------- | +| AMP_TYPE.TORCH | ✅ | ❌ | Model parameters, activation, gradients are downcast to fp16 during forward and backward propagation | +| AMP_TYPE.APEX | ❌ | ❌ | More fine-grained, we can choose opt_level O0, O1, O2, O3 | +| AMP_TYPE.NAIVE | ✅ | ✅ | Model parameters, forward and backward operations are all downcast to fp16 | + +The first two rely on the original implementation of PyTorch (version 1.6 and above) and NVIDIA Apex. +The last method is similar to Apex O2 level. +Among these methods, apex AMP is not compatible with tensor parallelism. +This is because that tensors are split across devices in tensor parallelism, thus, it is required to communicate among different processes to check if inf or nan occurs in the whole model weights. +We modified the torch amp implementation so that it is compatible with tensor parallelism now. + +> ❌️ fp16 and zero configuration are not compatible +> +> ⚠️ Pipeline only support naive AMP currently + +We recommend you to use torch AMP as it generally gives better accuracy than naive AMP if no pipeline is used. + +## Table of Contents + +In this tutorial we will cover: + +1. AMP introduction +2. AMP in Colossal-AI +3. Hands-on Practice + +## AMP Introduction + +Automatic Mixed Precision training is a mixture of FP16 and FP32 training. + +Half-precision float point format (FP16) has lower arithmetic complexity and higher compute efficiency. +Besides, fp16 requires half of the storage needed by fp32 and saves memory & network bandwidth, which makes more memory +available for large batch size and model size. + +However, there are other operations, like reductions, which require the dynamic range of fp32 to avoid numeric overflow/underflow. That's the reason why we introduce automatic mixed precision, attempting to match each operation to its appropriate data type, which can reduce the memory footprint and augment training efficiency. + +
+ +
Illustration of an ordinary AMP (figure from PatrickStar paper)
+
+ +## AMP in Colossal-AI + +We supported three AMP training methods and allowed the user to train with AMP with no code. You can just simply add `fp16` +configuration in your configuration file to use AMP. + + +```python +from colossalai.amp import AMP_TYPE + +# use Torch AMP +fp16=dict( + mode = AMP_TYPE.TORCH +) + +# use naive AMP +fp16=dict( + mode = AMP_TYPE.NAIVE +) + +# use NVIDIA Apex AMP +fp16=dict( + mode = AMP_TYPE.APEX +) + +``` + +> These are the minimum configuration, full configuration are stated in the section later + +### AMP Modularity + +AMP module is designed to be completely modular and can be used independently. +If you wish to only use AMP in your code base without `colossalai.initialize`, +you can use `colossalai.amp.convert_to_amp`. + +```python +from colossalai.amp import AMP_TYPE + +# example of using torch amp +model, optimizer, criterion = colossalai.amp.convert_to_amp(model, + optimizer, + criterion, + AMP_TYPE.TORCH) +``` + +### Torch AMP Configuration + +```python +from colossalai.amp import AMP_TYPE + +fp16=dict( + mode=AMP_TYPE.TORCH, + + # below are default values for grad scaler + init_scale=2.**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + enabled=True +) +``` + +With optional arguments: +- init_scale(float, optional, default=2.**16): Initial scale factor +- growth_factor(float, optional, default=2.0): Factor by which the scale is multiplied during `update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. +- backoff_factor(float, optional, default=0.5): Factor by which the scale is multiplied during `update` if inf/NaN gradients occur in an iteration. +- growth_interval(int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by ``growth_factor``. +- enabled(bool, optional, default=True): If ``False``, disables gradient scaling. `step` simply invokes the underlying ``optimizer.step()``, and other methods become no-ops. + +### Apex AMP Configuration + +For this mode, we rely on the Apex implementation for mixed precision training. +We support this plugin because it allows for finer control on the granularity of mixed precision. +For example, O2 level (optimization level 2) will keep batch normalization in fp32. + +If you look for more details, please refer to [Apex Documentation](https://nvidia.github.io/apex/). + +```python +from colossalai.amp import AMP_TYPE + +fp16 = dict( + mode=AMP_TYPE.APEX, + + # below are the default values + enabled=True, + opt_level='O1', + cast_model_type=None, + patch_torch_functions=None, + keep_batchnorm_fp32=None, + master_weights=None, + loss_scale=None, + cast_model_outputs=None, + num_losses=1, + verbosity=1, + min_loss_scale=None, + max_loss_scale=16777216.0 +) +``` + +Parameters: +- enabled(bool, optional, default=True): If False, renders all AMP calls no-ops, so your script should run as if Amp were not present. + +- opt_level(str, optional, default="O1" ): Pure or mixed precision optimization level. +Accepted values are “O0”, “O1”, “O2”, and “O3”, explained in detail above Apex AMP Documentation. + +- num_losses(int, optional, default=1): Option to tell AMP in advance how many losses/backward passes you plan to use. +When used in conjunction with the loss_id argument to `amp.scale_loss`, enables Amp to use a different loss scale per +loss/backward pass, which can improve stability. If num_losses is left to 1, Amp will still support multiple +losses/backward passes, but use a single global loss scale for all of them. + +- verbosity(int, default=1): Set to 0 to suppress Amp-related output. + +- min_loss_scale(float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic loss scaling. +The default value of None means that no floor is imposed. If dynamic loss scaling is not used, min_loss_scale is ignored. + +- max_loss_scale(float, default=2.**24 ): Sets a ceiling for the loss scale values that can be chosen by dynamic loss +scaling. If dynamic loss scaling is not used, max_loss_scale is ignored. + +Currently, the under-the-hood properties that govern pure or mixed precision training are the following: +cast_model_type, patch_torch_functions, keep_batchnorm_fp32, master_weights, loss_scale. +They are optional properties override once opt_level is determined + +- cast_model_type: Casts your model’s parameters and buffers to the desired type. +- patch_torch_functions: Patch all Torch functions and Tensor methods to perform Tensor Core-friendly ops like GEMMs and convolutions in FP16, and any ops that benefit from FP32 precision in FP32. +- keep_batchnorm_fp32: To enhance precision and enable cudnn batchnorm (which improves performance), it’s often beneficial to keep batchnorm weights in FP32 even if the rest of the model is FP16. +- master_weights: Maintain FP32 master weights to accompany any FP16 model weights. FP32 master weights are stepped by the optimizer to enhance precision and capture small gradients. +- loss_scale: If loss_scale is a float value, use this value as the static (fixed) loss scale. If loss_scale is the string "dynamic", adaptively adjust the loss scale over time. Dynamic loss scale adjustments are performed by Amp automatically. + + +### Naive AMP Configuration + +In Naive AMP mode, we achieved mixed precision training while maintaining compatibility with complex tensor and pipeline parallelism. +This AMP mode will cast all operations into fp16. +The following code block shows the `config.py` file for this mode. + +```python +from colossalai.amp import AMP_TYPE + +fp16 = dict( + mode=AMP_TYPE.NAIVE, + + # below are the default values + log_num_zeros_in_grad=False, + initial_scale=2 ** 32, + min_scale=1, + growth_factor=2, + backoff_factor=0.5, + growth_interval=1000, + hysteresis=2 +) +``` + +The default parameters of Naive AMP: +- log_num_zeros_in_grad(bool): return number of zeros in the gradients. +- initial_scale(int): initial scale of gradient scaler +- growth_factor(int): the growth rate of loss scale +- backoff_factor(float): the decrease rate of loss scale +- hysteresis(int): delay shift in dynamic loss scaling +- max_scale(int): maximum loss scale allowed +- verbose(bool): if set to `True`, will print debug info + +When using `colossalai.initialize`, you are required to first instantiate a model, an optimizer and a criterion. +The output model is converted to AMP model of smaller memory consumption. +If your input model is already too large to fit in a GPU, please instantiate your model weights in `dtype=torch.float16`. +Otherwise, try smaller models or checkout more parallelization training techniques! + + +## Hands-on Practice + +We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/amp) which demonstrates +the use of AMP with Colossal-AI. In this practice, we will use Torch AMP as an example, but do note that config files are provided for all AMP modes. + +### Step 1. Create a config file + +Create a `config.py` and add the `fp16` configuration. + +```python +# in config.py +from colossalai.amp import AMP_TYPE + +BATCH_SIZE = 128 +DROP_RATE = 0.1 +NUM_EPOCHS = 300 + +fp16 = dict( + mode=AMP_TYPE.TORCH, +) + +clip_grad_norm = 1.0 +``` + +### Step 2. Import libraries in train_with_engine.py + +Create a `train_with_engine.py` and import the necessary dependencies. Remember to install `scipy` and `timm` by running +`pip install timm scipy`. + +```python +import os +import colossalai +import torch +from pathlib import Path +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.utils import get_dataloader +from colossalai.trainer import Trainer, hooks +from colossalai.nn.lr_scheduler import LinearWarmupLR +from timm.models import vit_base_patch16_224 +from torchvision import datasets, transforms + +``` + +### Step 3. Initialize Distributed Environment + +We then need to initialize distributed environment. For demo purpose, we uses `launch_from_torch`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md) +for other initialization methods. + +```python +# initialize distributed setting +parser = colossalai.get_default_parser() +args = parser.parse_args() + +# launch from torch +colossalai.launch_from_torch(config=args.config) + +``` + +### Step 4. Create training components + +Build your model, optimizer, loss function, lr scheduler and dataloaders. Note that the root path of the dataset is +obtained from the environment variable `DATA`. You may `export DATA=/path/to/data` or change `Path(os.environ['DATA'])` +to a path on your machine. Data will be automatically downloaded to the root path. + +```python +# build model + model = vit_base_patch16_224(drop_rate=0.1) + + # build dataloader + train_dataset = datasets.Caltech101( + root=Path(os.environ['DATA']), + download=True, + transform=transforms.Compose([ + transforms.Resize(256), + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + Gray2RGB(), + transforms.Normalize([0.5, 0.5, 0.5], + [0.5, 0.5, 0.5]) + ])) + + train_dataloader = get_dataloader(dataset=train_dataset, + shuffle=True, + batch_size=gpc.config.BATCH_SIZE, + num_workers=1, + pin_memory=True, + ) + + # build optimizer + optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=0.1) + + # build loss + criterion = torch.nn.CrossEntropyLoss() + + # lr_scheduler + lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS) +``` + +### Step 5. Inject AMP Feature + +Call `colossalai.initialize` to convert the training components to be running with FP16. + +```python +engine, train_dataloader, _, _ = colossalai.initialize( + model, optimizer, criterion, train_dataloader, + ) +``` + +### Step 6. Train with Engine + +Use engine in a normal training loops. + +```python +engine.train() +for epoch in range(gpc.config.NUM_EPOCHS): + for img, label in enumerate(train_dataloader): + img = img.cuda() + label = label.cuda() + engine.zero_grad() + output = engine(img) + loss = engine.criterion(output, label) + engine.backward(loss) + engine.step() + lr_scheduler.step() +``` + +### Step 7. Invoke Training Scripts + +Use the following command to start the training scripts. You can change `--nproc_per_node` to use a different number of GPUs. + +```python +python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py --config config/config_AMP_torch.py +``` diff --git a/docs/source/en/features/nvme_offload.md b/docs/source/en/features/nvme_offload.md new file mode 100644 index 0000000000000000000000000000000000000000..4374da3c9c4558f5aca8e7d15603e95ef0477573 --- /dev/null +++ b/docs/source/en/features/nvme_offload.md @@ -0,0 +1,263 @@ +# NVMe offload + +Author: Hongxin Liu + +**Prerequisite:** +- [Zero Redundancy Optimizer with chunk-based memory management](../features/zero_with_chunk.md) + +**Related Paper** + +- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840) +- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857) + +## Introduction + +If a model has `N` parameters, when using Adam, it has `8N` optimizer states. For billion-scale models, optimizer states take at least 32 GB memory. GPU memory limits the model scale we can train, which is called GPU memory wall. If we offload optimizer states to the disk, we can break through GPU memory wall. + +We implement a user-friendly and efficient asynchronous Tensor I/O library: [TensorNVMe](https://github.com/hpcaitech/TensorNVMe). With this library, we can simply implement NVMe offload. + +> This library is compatible with all kinds of disk (HDD, SATA SSD, and NVMe SSD). As I/O bandwidth of HDD or SATA SSD is low, it's recommended to use this lib only on NVMe disk. + +When optimizing a parameter, we can divide the optimization process into three stages: read, compute and offload. We perform the optimization process in a pipelined fashion, which can overlap computation and I/O. + +
+ +
Optimization process
+
+ +## Usage + +First, please make sure you installed [TensorNVMe](https://github.com/hpcaitech/TensorNVMe): + +```shell +pip install packaging +pip install tensornvme +``` + +We implement NVMe offload of optimizer states for Adam ([CPUAdam](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.cpu_adam.html) and [HybridAdam](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.hybrid_adam.html)). + + + + +```python +from colossalai.nn.optimizer import CPUAdam, HybridAdam + +optimizer = HybridAdam(model.parameters(), lr=1e-3, nvme_offload_fraction=1.0, nvme_offload_dir='./') +``` + + + +`nvme_offload_fraction` is the fraction of optimizer states to be offloaded to NVMe. `nvme_offload_dir` is the directory to save NVMe offload files. If `nvme_offload_dir` is `None`, a random temporary directory will be used. + +It's compatible with all parallel methods in ColossalAI. + +> ⚠ It only offloads optimizer states on CPU. This means it only affects CPU training or Zero/Gemini with offloading. + +## Exampls + +Let's start from two simple examples -- training GPT with different methods. These examples relies on `transformers`. + +We should install dependencies first: + +```shell +pip install psutil transformers +``` + +First, we import essential packages and modules: + +```python +import os +import time +from typing import Dict, Optional + +import psutil +import torch +import torch.nn as nn +from transformers.models.gpt2.configuration_gpt2 import GPT2Config +from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel + +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.zero import zero_model_wrapper, zero_optim_wrapper +from colossalai.utils.model.colo_init_context import ColoInitContext +``` + +Then we define a loss function: + +```python +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1)) +``` + +And we define some utility functions, which generates random data, computes the number of parameters of a model and get memory usage of current process: + +```python +def get_data(batch_size: int, seq_len: int, + vocab_size: int, device: Optional[str] = None) -> Dict[str, torch.Tensor]: + device = torch.cuda.current_device() if device is None else device + input_ids = torch.randint(vocab_size, (batch_size, seq_len), + device=device) + attn_mask = torch.ones_like(input_ids) + return dict(input_ids=input_ids, attention_mask=attn_mask) + + +def get_model_numel(model: nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) + + +def get_mem_usage() -> int: + proc = psutil.Process(os.getpid()) + return proc.memory_info().rss +``` + +We first try to train GPT model on CPU: + +```python +def train_cpu(nvme_offload_fraction: float = 0.0): + config = GPT2Config() + model = GPT2LMHeadModel(config) + criterion = GPTLMLoss() + optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction) + print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B') + + start = time.time() + for step in range(3): + data = get_data(4, 128, config.vocab_size, device='cpu') + outputs = model(**data) + loss = criterion(outputs.logits, data['input_ids']) + loss.backward() + optimizer.step() + optimizer.zero_grad() + print(f'[{step}] loss: {loss.item():.3f}') + + print(f'Time: {time.time() - start:.3f} s') + print(f'Mem usage: {get_mem_usage() / 1024**2:.3f} MB') +``` + +Run without NVME offload: + +```python +train_cpu(0.0) +``` + +We may get below output: + +``` +Model numel: 0.116 B +[0] loss: 10.953 +[1] loss: 10.974 +[2] loss: 10.965 +Time: 7.739 s +Mem usage: 5966.445 MB +``` + +And then run with (full) NVME offload: + +```python +train_cpu(1.0) +``` + +We may get: + +``` +Model numel: 0.116 B +[0] loss: 10.951 +[1] loss: 10.994 +[2] loss: 10.984 +Time: 8.527 s +Mem usage: 4968.016 MB +``` + +For GPT2-S, which has 0.116 billion parameters, its optimizer states take about 0.928 GB memory. And NVME offload saves about 998 MB memory, which meets our expectations. + +Then we can train GPT model with Gemini. The placement policy of Gemini should be `"auto"`, `"cpu"` or `"const"`. + +```python +def train_gemini_cpu(nvme_offload_fraction: float = 0.0): + colossalai.launch_from_torch({}) + config = GPT2Config() + with ColoInitContext(device=torch.cuda.current_device()): + model = GPT2LMHeadModel(config) + criterion = GPTLMLoss() + optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction) + print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B') + + gemini_config = dict(strict_ddp_mode=True, device=torch.cuda.current_device(), + placement_policy='cpu', pin_memory=True, hidden_dim=config.n_embd) + model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config) + optimizer = zero_optim_wrapper(model, optimizer, initial_scale=2**5) + + start = time.time() + for step in range(3): + data = get_data(4, 128, config.vocab_size) + outputs = model(**data) + loss = criterion(outputs.logits, data['input_ids']) + optimizer.backward(loss) + optimizer.step() + optimizer.zero_grad() + print(f'[{step}] loss: {loss.item():.3f}') + + print(f'Time: {time.time() - start:.3f} s') + print(f'Mem usage: {get_mem_usage() / 1024**2:.3f} MB') +``` + +Run without NVME offload: + +```python +train_gemini_cpu(0.0) +``` + +We may get: + +``` +Model numel: 0.116 B +searching chunk configuration is completed in 0.27 s. +used number: 118.68 MB, wasted number: 0.75 MB +total wasted percentage is 0.63% +[0] loss: 10.953 +[1] loss: 10.938 +[2] loss: 10.969 +Time: 2.997 s +Mem usage: 5592.227 MB +``` + +And run with (full) NVME offload: + +```python +train_gemini_cpu(1.0) +``` + +We may get: + +``` +Model numel: 0.116 B +searching chunk configuration is completed in 0.27 s. +used number: 118.68 MB, wasted number: 0.75 MB +total wasted percentage is 0.63% +[0] loss: 10.953 +[1] loss: 10.938 +[2] loss: 10.969 +Time: 3.691 s +Mem usage: 5298.344 MB +``` + +NVME offload saves about 294 MB memory. Note that enabling `pin_memory` of Gemini can accelerate training but increase memory usage. So this result also meets our expectation. If we disable `pin_memory`, we can also observe a memory usage drop about 900 MB. + +## API Reference + +{{ autodoc:colossalai.nn.optimizer.HybridAdam }} + +{{ autodoc:colossalai.nn.optimizer.CPUAdam }} + + + diff --git a/docs/source/en/features/pipeline_parallel.md b/docs/source/en/features/pipeline_parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..ac49863b3c719241f8e7cc6349c3730c85519d03 --- /dev/null +++ b/docs/source/en/features/pipeline_parallel.md @@ -0,0 +1,159 @@ +# Pipeline Parallel + +Author: Guangyang Lu, Hongxin Liu, Yongbin Li + +**Prerequisite** +- [Define Your Configuration](../basics/define_your_config.md) +- [Use Engine and Trainer in Training](../basics/engine_trainer.md) +- [Configure Parallelization](../basics/configure_parallelization.md) + +**Example Code** +- [ColossalAI-Examples ResNet with pipeline](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/pipeline_parallel) + +**Related Paper** +- [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883) +- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) +- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965) + +## Quick introduction + +In this tutorial, you will learn how to use pipeline parallel. In Colossal-AI, we use 1F1B pipeline, introduced by Nvidia. In this case, ViT and Imagenet are too large to use. Therefore, here we use ResNet and Cifar as example. + +## Table Of Content + +In this tutorial we will cover: + +1. Introduction of 1F1B pipeline. +2. Usage of non-interleaved and interleaved schedule. +3. Training ResNet with pipeline. + +## Introduction of 1F1B pipeline + +First of all, we will introduce you GPipe for your better understanding. + +
+ +
Figure1: GPipe. This figure is from Megatron-LM paper.
+
+ + +As you can see, for GPipe, only when the forward passes of all microbatches in a batch finish, the backward passes would be executed. + +In general, 1F1B(one forward pass followed by one backward pass) is more efficient than GPipe(in memory or both memory and time). There are two schedules of 1F1B pipeline, the non-interleaved and the interleaved. The figures are shown below. + +
+ +
Figure2: This figure is from Megatron-LM paper. The top part shows the default non-interleaved schedule. And the bottom part shows the interleaved schedule.
+
+ +### Non-interleaved Schedule + +The non-interleaved schedule can be divided into three stages. The first stage is the warm-up stage, where workers perform differing numbers of forward passes. At the following stage, workers perform one forward pass followed by one backward pass. Workers will finish backward passes at the last stage. + +This mode is more memory-efficient than GPipe. However, it would take the same time to finish a turn of passes as GPipe. + +### Interleaved Schedule + +This schedule requires **the number of microbatches to be an integer multiple of the stage of pipeline**. + +In this schedule, each device can perform computation for multiple subsets of layers(called a model chunk) instead of a single contiguous set of layers. i.e. Before device 1 had layer 1-4; device 2 had layer 5-8; and so on. But now device 1 has layer 1,2,9,10; device 2 has layer 3,4,11,12; and so on. With this scheme, each device in the pipeline is assigned multiple pipeline stages and each pipeline stage has less computation. + +This mode is both memory-efficient and time-efficient. + +## Usage of non-interleaved and interleaved schedule + +In Colossal-AI, we provided both non-interleaved(as `PipelineSchedule`) and interleaved schedule(as `InterleavedPipelineSchedule`). + +You just need to set `NUM_MICRO_BATCHES` in config file and set `NUM_CHUNKS` in config file if you want to use Interleaved Pipeline Schedule. If you certainly know the shape of each pipeline stage's output tensor and the shapes are all the same, you can set `TENSOR_SHAPE` in config file to further reduce communication. Otherwise, you can just ignore `tensor_shape`, and the shape will be exchanged over pipeline stages automatically. Then we will generate an appropriate schedule for you. + +## Training ResNet with pipeline + +Let's build the `ResNet` model first with Colossal PipelinableContext: +```python +import os +from typing import Callable, List, Optional, Type, Union +import torch +import torch.nn as nn +import colossalai +import colossalai.nn as col_nn + +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.trainer import Trainer, hooks +from colossalai.utils import MultiTimer, get_dataloader +from colossalai.context import ParallelMode +from colossalai.pipeline.pipelinable import PipelinableContext + +from titans.dataloader.cifar10 import build_cifar +from torchvision.models import resnet50 +from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 + +# Define some config +BATCH_SIZE = 64 +NUM_EPOCHS = 2 +NUM_CHUNKS = 1 +CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2)) + +# Train +disable_existing_loggers() +parser = colossalai.get_default_parser() +args = parser.parse_args() +colossalai.launch_from_torch(backend=args.backend, config=CONFIG) +logger = get_dist_logger() +pipelinable = PipelinableContext() + +# build model +with pipelinable: + model = resnet50() +``` + +Define an execution sequence. +```python +exec_seq = [ + 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', + (lambda x: torch.flatten(x, 1), "behind"), 'fc' +] +pipelinable.to_layer_list(exec_seq) +``` + +Partition the model into pipeline. +```python +model = pipelinable.partition(NUM_CHUNKS, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) +``` + +In this tutorial, we use `Trainer` to train `ResNet`: +```python +# build criterion +criterion = nn.CrossEntropyLoss() + +# optimizer +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + +# build dataloader +root = os.environ.get('DATA', './data') +train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, padding=4, crop=32, resize=32) + +lr_scheduler = col_nn.lr_scheduler.LinearWarmupLR(optimizer, NUM_EPOCHS, warmup_steps=1) +engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion, + train_dataloader, test_dataloader, + lr_scheduler) +timer = MultiTimer() + +trainer = Trainer(engine=engine, timer=timer, logger=logger) + +hook_list = [ + hooks.LossHook(), + hooks.AccuracyHook(col_nn.metric.Accuracy()), + hooks.LogMetricByEpochHook(logger), + hooks.LRSchedulerHook(lr_scheduler, by_epoch=True) +] + +trainer.fit(train_dataloader=train_dataloader, + epochs=NUM_EPOCHS, + test_dataloader=test_dataloader, + test_interval=1, + hooks=hook_list, + display_progress=True) +``` + +We use `2` pipeline stages and the batch will be splitted into `4` micro batches. diff --git a/docs/source/en/features/zero_with_chunk.md b/docs/source/en/features/zero_with_chunk.md new file mode 100644 index 0000000000000000000000000000000000000000..a105831a54099456cc6489cbb9c5599b0b262933 --- /dev/null +++ b/docs/source/en/features/zero_with_chunk.md @@ -0,0 +1,265 @@ +# Zero Redundancy Optimizer with chunk-based memory management + +Author: [Hongxiu Liu](https://github.com/ver217), [Jiarui Fang](https://github.com/feifeibear), [Zijian Ye](https://github.com/ZijianYY) + +**Prerequisite:** +- [Define Your Configuration](../basics/define_your_config.md) + +**Example Code** + +- [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt) + +**Related Paper** + +- [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054) +- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840) +- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857) +- [DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters](https://dl.acm.org/doi/10.1145/3394486.3406703) +- [PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management](https://arxiv.org/abs/2108.05818) + +## Introduction + +The Zero Redundancy Optimizer (ZeRO) removes the memory redundancies across data-parallel processes by partitioning three +model states (optimizer states, gradients, and parameters) instead of replicating them. +By doing so, memory efficiency is boosted drastically compared to classic data parallelism, while the computational granularity +and communication efficiency is retained. + +1. **Shard Optimizer States**: The optimizer states (e.g., for [Adam optimizer](https://arxiv.org/abs/1412.6980), 32-bit weights, +and the first and second momentum estimates) are partitioned across the processes, so that each process updates only its partition. + + +2. **Shard Gradient**: After reduction inside data parallel process group, gradient tensors are also partitioned such that each process only stores the gradients corresponding to its partition of the optimizer states. Note, Colossal converts gradient into fp32 format to participate in parameter updating. + +3. **Shard Parameter**: The 16-bit model parameters are partitioned across the processes of a data parallel group. + +4. **[Gemini](../advanced_tutorials/meet_gemini.md)**: Dynamic heterogeneous memory space manager for parameters, gradients and optimizer states. + +Besides, this article will introduce the Zero Redundancy Optimizer with chunk-based memory management. + +When using ZeRO, we distributed the model by sharding the parameters. The advantage of this method is that the memory of each node is load balanced. But this approach has two significant disadvantages. First, during communication, a temporary memory buffer needs to be allocated and released afterwards, leading to the memory fragmentation problem. Secondly, using tensor as the granularity for communication will cause the network bandwidth underutilized. Generally, the longer the transmitted message length, the higher the bandwidth utilization. + +Using the Chunk mechanism introduced in ColossalAI v0.1.8, we can improve the efficiency of ZeRO. We store a continuous set of parameters in initialization order into a Chunk (a chunk is a continuous memory space), and each Chunk has the same size. Organizing memory in chunks can lead to efficient use of network bandwidth between PCI-e and GPU-GPU, reduce the number of communications, and avoid potential memory fragmentation. + +Before v0.1.8, ZeRO had a high communication cost for parameter communications. If a parameter was used multiple times in several consecutive operators, there will be repeated communications operations, and the efficiency was highly damaged. This situation is very common when using the Gradient Checkpoint technique, and the parameter will recompute the forward propagation during backward propagation. + +Taking GPT as an example, its Checkpoint will be applied to each GPT Block, and each GPT Block contains a Self-Attention layer and an MLP layer. During the backward pass, the forward of the Self-Attention layer and the MLP layer will be computed in turn, and then the backward of the MLP layer and the Self-Attention layer will be computed in turn. + +In addition, due to the communication and memory movement of small Tensors, the bandwidth of NVLINK and PCI-E cannot be fully utilized, and each communication and memory movement has the overhead of kernel launch. After using Chunk, multiple small Tensor communication and memory movement can be changed into one large Tensor communication and memory movement, which not only improves bandwidth utilization but also reduces the overhead of kernel launch. + +We also provide a lightweight chunk search mechanism to help users automatically find the chunk size with the smallest memory fragmentation. + +## Usage + +### GeminiDDP + +We will use `GeminiDDP` to use ZeRO with chunk-based memory management. This is our new torch.Module wrapper which uses ZeRO-DP and Gemini. ZeRO is for parallelism and Gemini is for memory management. + +Also Make sure that your model is initialized under the context of ColoInitContext. + +```python +with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg): + model = gpt2_medium(checkpoint=True) +``` + +Define the model parameters as follows: + +```python +chunk_manager = init_chunk_manager(model=module, + init_device=device, + hidden_dim=hidden_dim, + search_range_mb=search_range_mb, + min_chunk_size_mb=min_chunk_size_mb) +gemini_manager = GeminiManager(placement_policy, chunk_manager) +``` + +`hidden_dim` is the hidden dimension of DNN. Users can provide this argument to speed up searching. If users do not know this argument before training, it is ok. We will use a default value 1024. `min_chunk_size_mb` is the the minimum chunk size in MegaByte. If the aggregate size of parameters is still samller than the minimum chunk size, all parameters will be compacted into one small chunk. + +Initialization of the optimizer. +```python +optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5) +``` + +Training +```python +optimizer.zero_grad() +outputs = model(input_ids, attn_mask) +loss = criterion(outputs, input_ids) +optimizer.backward(loss) +optimizer.step() +``` +> ⚠️ Note: Please do not use `loss.backward()`, the standard way of writing is `optimizer.backward(loss)`. + +### Train GPT + +In this example, we use `Hugging Face Transformers`. You have to install `transformers` before running this example. We will take `GPT2 Medium` as an example here. + +For simplicity, we just use randomly generated data here. + +First we only need to import `GPT2LMHeadModel` from `Huggingface transformers` to define our model, which does not require users to define or modify the model, so that users can use it more conveniently. + +```python +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + +def gpt2_medium(checkpoint=False): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) +``` + +Define our loss function: + +```python +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) +``` + +Define tensor parallel and parameter sharding strategies for tensor parallelism: + +```python +def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): + for mn, module in model.named_modules(): + for pn, param in module.named_parameters(recurse=False): + if hasattr(param, 'visited'): + continue + param.set_dist_spec(ReplicaSpec()) + if 'mlp.c_fc' in mn: + if 'weight' in pn or 'bias' in pn: + split_param_col_tp1d(param, pg) + param.compute_spec.set_output_replicate(False) + else: + param.set_dist_spec(ReplicaSpec()) + elif 'mlp.c_proj' in mn: + if 'weight' in pn: + split_param_row_tp1d(param, pg) + else: + param.set_dist_spec(ReplicaSpec()) + elif 'wte' in mn or 'wpe' in mn: + split_param_col_tp1d(param, pg) + elif 'c_attn' in mn or 'c_proj' in mn: + split_param_col_tp1d(param, pg) + else: + param.set_dist_spec(ReplicaSpec()) + + param.visited = True +def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(-1, param, pg) +``` + +Define a model which uses Gemini + ZeRO DDP: + +```python +def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): + cai_version = colossalai.__version__ + if version.parse(cai_version) > version.parse("0.1.10"): + from colossalai.nn.parallel import GeminiDDP + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=placememt_policy, + pin_memory=True, + search_range_mb=32) + elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): + from colossalai.gemini import ChunkManager, GeminiManager + chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) + gemini_manager = GeminiManager(placememt_policy, chunk_manager) + chunk_manager = ChunkManager(chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(placememt_policy)) + model = ZeroDDP(model, gemini_manager) + else: + raise NotImplemented(f"CAI version {cai_version} is not supported") + return model +``` + +As we pre-train GPT in this example, we just use a simple language model loss. + +Write a function to get random inputs: + +```python +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask +``` + +Finally, we can define our training loop: + +```python +def main(): + args = parse_args() + BATCH_SIZE = 8 + SEQ_LEN = 1024 + VOCAB_SIZE = 50257 + NUM_STEPS = 10 + colossalai.launch_from_torch(config={}) + + # build criterion + criterion = GPTLMLoss() + + torch.manual_seed(123) + default_pg = ProcessGroup(tp_degree=args.tp_degree) + default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None + # build GPT model + with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg): + model = gpt2_medium(checkpoint=True) + pg = default_pg + # Tensor Parallelism (TP) + tensor_parallelize(model, pg) + # Gemini + ZeRO DP, Note it must be used after TP + model = gemini_zero_dpp(model, pg, args.placement) + # build optimizer + optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5) + numel = sum([p.numel() for p in model.parameters()]) + get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN) + torch.cuda.synchronize() + model.train() + for n in range(NUM_STEPS): + # we just use randomly generated data here + input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE) + optimizer.zero_grad() + outputs = model(input_ids, attn_mask) + loss = criterion(outputs, input_ids) + optimizer.backward(loss) + optimizer.step() + + torch.cuda.synchronize() +``` +> ⚠️ Note: If you want to use the Gemini module, please do not use the [Gradient Accumulation](../features/gradient_accumulation.md) we mentioned before。 +The complete example can be found on [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). diff --git a/docs/source/en/get_started/installation.md b/docs/source/en/get_started/installation.md new file mode 100644 index 0000000000000000000000000000000000000000..290879219074bfc940705bf20f8167d08c941166 --- /dev/null +++ b/docs/source/en/get_started/installation.md @@ -0,0 +1,52 @@ +# Setup + +Requirements: +- PyTorch >= 1.11 (PyTorch 2.x in progress) +- Python >= 3.7 +- CUDA >= 11.0 +- [NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus) >= 7.0 (V100/RTX20 and higher) +- Linux OS + +If you encounter any problem about installation, you may want to raise an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) in this repository. + + +## Download From PyPI + +You can install Colossal-AI with + +```shell +pip install colossalai +``` + +**Note: only Linux is supported for now** + +If you want to build PyTorch extensions during installation, you can use the command below. Otherwise, the PyTorch extensions will be built during runtime. + +```shell +CUDA_EXT=1 pip install colossalai +``` + + +## Download From Source + +> The version of Colossal-AI will be in line with the main branch of the repository. Feel free to raise an issue if you encounter any problem. :) + +```shell +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI + +# install dependency +pip install -r requirements/requirements.txt + +# install colossalai +pip install . +``` + +If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer): + +```shell +CUDA_EXT=1 pip install . +``` + + + diff --git a/docs/source/en/get_started/reading_roadmap.md b/docs/source/en/get_started/reading_roadmap.md new file mode 100644 index 0000000000000000000000000000000000000000..476c524ac011b3b85cd3480ae8b22f6d3f2fee59 --- /dev/null +++ b/docs/source/en/get_started/reading_roadmap.md @@ -0,0 +1,19 @@ +# Reading Roadmap + +Colossal-AI provides a collection of parallel training components for you. We aim to support you with your development +of distributed deep learning models just like how you write single-GPU deep learning models. ColossalAI provides easy-to-use +APIs to help you kickstart your training process. To better how ColossalAI works, we recommend you to read this documentation +in the following order. + +- If you are not familiar with distributed system or have never used Colossal-AI, you should first jump into the `Concepts` +section to get a sense of what we are trying to achieve. This section can provide you with some background knowledge on +distributed training as well. +- Next, you can follow the `basics` tutorials. This section will cover the details about how to use Colossal-AI. +- Afterwards, you can try out the features provided in Colossal-AI by reading `features` section. We will provide a codebase for each tutorial. These tutorials will cover the +basic usage of Colossal-AI to realize simple functions such as data parallel and mixed precision training. +- Lastly, if you wish to apply more complicated techniques such as how to run hybrid parallel on GPT-3, the +`advanced tutorials` section is the place to go! + +**We always welcome suggestions and discussions from the community, and we would be more than willing to help you if you +encounter any issue. You can raise an [issue](https://github.com/hpcaitech/ColossalAI/issues) here or create a discussion +topic in the [forum](https://github.com/hpcaitech/ColossalAI/discussions).** diff --git a/docs/source/en/get_started/run_demo.md b/docs/source/en/get_started/run_demo.md new file mode 100644 index 0000000000000000000000000000000000000000..f47bdbbd62fc13d9168212558a678cfa2f5b8127 --- /dev/null +++ b/docs/source/en/get_started/run_demo.md @@ -0,0 +1,43 @@ +# Quick Demo + +Colossal-AI is an integrated large-scale deep learning system with efficient parallelization techniques. The system can +accelerate model training on distributed systems with multiple GPUs by applying parallelization techniques. The system +can also run on systems with only one GPU. Quick demos showing how to use Colossal-AI are given below. + +## Single GPU + +Colossal-AI can be used to train deep learning models on systems with only one GPU and achieve baseline +performances. We provided an example to [train ResNet on CIFAR10 dataset](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/resnet) +with only one GPU. You can find the example in [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples). +Detailed instructions can be found in its `README.md`. + +## Multiple GPUs + +Colossal-AI can be used to train deep learning models on distributed systems with multiple GPUs and accelerate the +training process drastically by applying efficient parallelization techniques. When we have several parallelism for you +to try out. + +#### 1. data parallel + +You can use the same [ResNet example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/resnet) as the +single-GPU demo above. By setting `--nproc_per_node` to be the number of GPUs you have on your machine, the example +is turned into a data parallel example. + +#### 2. hybrid parallel + +Hybrid parallel includes data, tensor, and pipeline parallelism. In Colossal-AI, we support different types of tensor +parallelism (i.e. 1D, 2D, 2.5D and 3D). You can switch between different tensor parallelism by simply changing the configuration +in the `config.py`. You can follow the [GPT example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt). +Detailed instructions can be found in its `README.md`. + +#### 3. MoE parallel + +We provided [an example of WideNet](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet) to demonstrate +MoE parallelism. WideNet uses mixture of experts (MoE) to achieve better performance. More details can be found in +[Tutorial: Integrate Mixture-of-Experts Into Your Model](../advanced_tutorials/integrate_mixture_of_experts_into_your_model.md) + +#### 4. sequence parallel + +Sequence parallel is designed to tackle memory efficiency and sequence length limit problems in NLP tasks. We provided +[an example of BERT](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/bert/sequene_parallel) in +[ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples). You can follow the `README.md` to execute the code. diff --git a/docs/source/en/sidebar_category_translation.json b/docs/source/en/sidebar_category_translation.json new file mode 100644 index 0000000000000000000000000000000000000000..9cc320424e4071bf6ee280b137834fd042ce394e --- /dev/null +++ b/docs/source/en/sidebar_category_translation.json @@ -0,0 +1,26 @@ +{ + "sidebar.tutorialSidebar.category.Get started": { + "message": "Get started", + "description": "The label for category Get started in sidebar tutorialSidebar" + }, + "sidebar.tutorialSidebar.category.Concepts": { + "message": "Concepts", + "description": "The label for category Concepts in sidebar tutorialSidebar" + }, + "sidebar.tutorialSidebar.category.Basics": { + "message": "Basics", + "description": "The label for category Basics in sidebar tutorialSidebar" + }, + "sidebar.tutorialSidebar.category.Features": { + "message": "Features", + "description": "The label for category Features in sidebar tutorialSidebar" + }, + "sidebar.tutorialSidebar.category.Tensor Parallel": { + "message": "Tensor Parallel", + "description": "The label for category Tensor Parallel in sidebar tutorialSidebar" + }, + "sidebar.tutorialSidebar.category.Advanced Tutorials": { + "message": "Advanced Tutorials", + "description": "The label for category Advanced Tutorials in sidebar tutorialSidebar" + } +} diff --git a/docs/source/zh-Hans/Colossal-Auto/feature/auto_checkpoint.md b/docs/source/zh-Hans/Colossal-Auto/feature/auto_checkpoint.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/source/zh-Hans/Colossal-Auto/feature/device_mesh.md b/docs/source/zh-Hans/Colossal-Auto/feature/device_mesh.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/source/zh-Hans/Colossal-Auto/feature/layout_converting_management.md b/docs/source/zh-Hans/Colossal-Auto/feature/layout_converting_management.md new file mode 100644 index 0000000000000000000000000000000000000000..71bce57ea91b02fcfbf49ca4ed31c215bd7a6677 --- /dev/null +++ b/docs/source/zh-Hans/Colossal-Auto/feature/layout_converting_management.md @@ -0,0 +1,12 @@ +当一个张量在上下游算子中被要求的sharding spec不同时,我们需要进行分布转换处理(Layout Conversion)。目前主流的方式有两种,打表转换和逐维度转换。打表转换就是将所有可能的情况枚举出来,然后在遇到需要转换的情况下,去表格中找到对应的转换方案。 +为了解决这个问题,我们提出一个新奇的想法,使用启发式的搜索,来解决sharding spec的转换问题。 +然而它有一个很大问题,就是随着设备块(Device Mesh)的维度增加,这个问题的规模极具膨胀,以至于无法通过这种枚举打表的方式来解决。逐维度转换是对于一个N-d tensor的sharding spec,X0X1...Xn-1,我们让i从0到n-1逐维度地进行转换,这样不管设备块和张量的维度多少,我们都只需要一次扫描,就可以得到一个可行的转换操作序列,然而它问题是这样的转换效率会很差。为了解决这个问题,我们提出一个新奇的想法,使用启发式算法,来解决sharding spec的转换问题。,这个算法可以描述为: + 1. 从source spec生成所有的one-step transform sharding specs + 2. 在one-step transform sharding specs中,根据相似度函数,挑选一个”区别最小“的sharding spec作为后续的source sharding spec,并将该sharding spec记录在transform path中,如果one-step transform sharding spec中,有与target sharding spec相同的sharding spec,则算法结束。 + 3. 重复a,b直到算法结束 + +| Source/target sharding spec pairs |All gather | Shard | All to All | One step transform | Best sharding spec |Transform path| +| :-: | :-: | :-: | :-: | :-: | :-: |:-: | +| $S_{01}RR, RS_{01}R$ | $S_0RR$ | - | $S_0RS_1, S_0S_1R$ | $S_0RR, S_0RS_1, S_0S_1R$ | $S_0RR$ | $S_0RR$ +| $S_0RR, RS_{01}RR$ | $RRR$ | $S_0S_1R, S_0RS_1$ | $RS_0R, RRS_0$ | $RRR$, $S_0S_1R$, $S_0RS_1$, $RS_0R$, $RRS_0$ | $RS_0R$ | $S_0RR$ -> $RS_0R$ +| $RS_0R, RS_{01}RR$ | $RRR$ | $RS_{01}R, S_1S_0R, RS_0S_1$ | $S_0RR, RRS_0$ | $RRR$, $RS_{01}R$, $S_1S_0R$, $RS_0S_1$, $S_0RR$, $RRS_0$ | $RS_{01}R$ | $S_0RR$ -> $RS_0R$ -> $RS_{01}R$ diff --git a/docs/source/zh-Hans/Colossal-Auto/feature/tracer.md b/docs/source/zh-Hans/Colossal-Auto/feature/tracer.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/source/zh-Hans/Colossal-Auto/get_started/installation.md b/docs/source/zh-Hans/Colossal-Auto/get_started/installation.md new file mode 100644 index 0000000000000000000000000000000000000000..054b709c92d02969c2aeb43ef1242cc911a489aa --- /dev/null +++ b/docs/source/zh-Hans/Colossal-Auto/get_started/installation.md @@ -0,0 +1,28 @@ +# 安装 + +## 声明 + +我们的自动并行功能处于alpha版本,仍在快速的开发迭代中。我们会在兼容性和稳定性上做持续地改进。如果您遇到任何问题,欢迎随时提issue给我们。 + + +## 要求 + +我们需要一些额外的依赖性来支持自动并行功能。 请在使用自动平行之前安装它们。 + +### 安装PyTorch + +我们仅支持Pytorch 1.12,现在未测试其他版本。 将来我们将支持更多版本。 + +```bash +#conda +conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch +#pip +pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 +``` + +### 安装pulp和coin-or-cbc + +```bash +pip install pulp +conda install -c conda-forge coin-or-cbc +``` diff --git a/docs/source/zh-Hans/Colossal-Auto/get_started/introduction.md b/docs/source/zh-Hans/Colossal-Auto/get_started/introduction.md new file mode 100644 index 0000000000000000000000000000000000000000..bd5326d43220a53f1e569df962b233b5915d8ed4 --- /dev/null +++ b/docs/source/zh-Hans/Colossal-Auto/get_started/introduction.md @@ -0,0 +1,41 @@ +# 介绍 + +近年来,大规模机器学习模型的部署受到越来越多的重视。然而,目前常见的分布式大模型训练方案,都依赖用户**人工反复尝试**和系统专家的经验来进行配置部署。这对绝大多数AI开发者来说十分不友好,因为他们不希望将时间精力花费在研究分布式系统和试错上。 +Colossal-AI的**Colossal-Auto** 帮助AI开发者简化了大规模机器学习模型的部署过程。相比现有其他手动配置复杂并行策略和修改模型的解决方案,Colossal-Auto 仅需增加一行代码,提供 cluster 信息以及单机训练模型即可获得分布式训练能力,并且**原生支持包括 Hugging Face,Timm 等热门 AI 模型库**。 + + + +## 概览 + +
+ +
+ +## 用法 +```python +# wrap the model using auto_engine +model = autoparallelize(model, meta_input_samples) +# normal training loop +... +``` + + +## 图追踪 +Colossal-Auto 是**首个基于 PyTorch 框架使用静态图分析的自动并行系统**。PyTorch 作为一个动态图框架,获取其静态的执行计划是机器学习系统领域被长期研究的问题。Colossal-Auto 使用基于 torch.FX Tracer 的 ColoTracer 来完成对于最优并行策略的搜索。在 tracing 过程中推导并记录了每个 tensor 的元信息,例如 tensor shape,dims,dtype 等。因此 Colossal-AI 具有更好的模型泛化能力,而不是依靠模型名或手动修改来适配并行策略。 + + +## 细粒度分布式训练策略搜索 + +我们调研了很多现有的自动并行系统( Tofu , Flexflow , Alpa ),以及自动激活值检查点算法( Rotor , Sublinear ),在他们的启发下,我们开发一个基于PyTorch框架的自动并行系统Colossal-Auto。Colossal-Auto会在满足内存预算的限制下,以最快运行时间为目标,为每个 op 进行策略搜索,最终得到真实训练时的策略,包括每个 tensor 的切分策略,不同计算节点间需要插入的通信算子类型,是否要进行算子替换等。现有系统中的张量并行,数据并行,NVIDIA 在 Megatron-LM 等并行系统中使用的 column 切分和 row 切分并行等混合并行,都是自动并行可以搜索到的策略的子集。除了这些可以手动指定的并行方式外,Colossal-AI 有能力为每个 op 指定独特的并行方式,因此有可能找到比依赖专家经验和试错配置的手动切分更好的并行策略。 + + + +## 分布式 tensor 与 shape consistency 系统 + +与 PyTorch 最新发布的 DTensor 类似,Colossal-AI 也使用了 device mesh 对集群进行了抽象管理。具体来说,Colossal-AI 使用 sharding spec 对 tensor 的分布式存储状态进行标注,使用 shape consistency manager 自动地对同一 tensor 在不同 sharding spec 间进行转换。这让 Colossal-AI 的通用性和易用性极大地提升,借助 shape consistency manager 可以没有负担地切分 tensor,而不用担心上游 op 的 output 与下游的 input 在集群中的存储方式不同。 + + +相较于 PyTorch DTensor,Colossal-AI 有以下优势: ++ Colossal-AI 的 device mesh 可以 profiling 到集群性能指标,对不同的通信算子进行耗时估算。 ++ Colossal-AI 的 shape consistency 会贪心地搜索 sharding spec 间的转换方式,而不是朴素地逐 dimension 进行转换,这样能找到更高效的转换路径,进而使得 sharding spec 间的转换通信开销更小。 ++ 加入了 all_to_all 操作,使得 Colossal-AI 的扩展性更强,这在大规模集群上进行训练时,可以展现出很大的优势。 diff --git a/docs/source/zh-Hans/Colossal-Auto/get_started/run_demo.md b/docs/source/zh-Hans/Colossal-Auto/get_started/run_demo.md new file mode 100644 index 0000000000000000000000000000000000000000..19316e12b4d5e7015a5fe04cf04bd22fef52f041 --- /dev/null +++ b/docs/source/zh-Hans/Colossal-Auto/get_started/run_demo.md @@ -0,0 +1,12 @@ +# 快速上手 + +Colossal-AI 提供了业界急需的一套高效易用自动并行系统。相比现有其他手动配置复杂并行策略和修改模型的解决方案,Colossal-AI 仅需增加一行代码,提供 cluster 信息以及单机训练模型即可获得分布式训练能力。Colossal-Auto的快速上手示例如下。 + +### 1. 基本用法 +Colossal-Auto 可被用于为每一次操作寻找一个包含数据、张量(如1D、2D、序列化)的混合SPMD并行策略。您可参考[GPT 示例](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/experiments/auto_parallel)。 +详细的操作指引见其 `README.md`。 + +### 2. 与 activation checkpoint 结合 + +作为大模型训练中必不可少的显存压缩技术,Colossal-AI 也提供了对于 activation checkpoint 的自动搜索功能。相比于大部分将最大显存压缩作为目标的技术方案,Colossal-AI 的搜索目标是在显存预算以内,找到最快的 activation checkpoint 方案。同时,为了避免将 activation checkpoint 的搜索一起建模到 SPMD solver 中导致搜索时间爆炸,Colossal-AI 做了 2-stage search 的设计,因此可以在合理的时间内搜索到有效可行的分布式训练方案。 您可参考 [Resnet 示例](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/auto_parallel)。 +详细的操作指引见其 `README.md`。 diff --git a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..4825a6fa1d6c8b9c7c74fe8f3d769057bad60980 --- /dev/null +++ b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md @@ -0,0 +1,112 @@ +# 添加你自己的并行模式 + +作者: Shenggui Li, Yongbin Li + +**前置教程** +- [定义配置文件](../basics/define_your_config.md) +- [并行配置](../basics/configure_parallelization.md) + +## 引言 + +为了使研究人员和工程师能够以更少的努力将我们的系统扩展到其他新颖的大规模分布式训练算法,我们已经将训练生命周期中的各种组件解耦。你可以通过简单地继承基类来实现你自己的并行模式。 + +主要组件有: + +1. `ProcessGroupInitializer` +2. `GradientHandler` +3. `Schedule` + +**目前这需要对源代码进行一些改动,因此我们建议你用`-e`标志从源代码安装。`-e`标志使得安装是可编辑的,因此,你的代码变化将反映在你的Python运行时中。我们将在这方面努力,以避免在未来的版本中改变源代码。** + + +## 进程组初始化器 + +并行通常由进程组来管理,参与相同并行算法的进程被置于同一进程组。对于不同的并行算法,需要创建不同的进程组。 +Colossal-AI 为用户提供了一个全局 context,使他们能够轻松地管理进程组。如果你想添加新的进程组,你可以很容易地定义一个新的类并在你的配置文件中设置它。为了定义你自己的进程组创建方式,你可以按照下面的步骤来创建一个新的分布式初始化。 + +1. 在 `colossalai.context.parallel_mode.ParallelMode` 中添加你自己的并行模式。 + ```python + class ParallelMode(Enum): + GLOBAL = 'global' + DATA = 'data' + PIPELINE = 'pipe' + ... + + NEW_MODE = 'new_mode' # define your mode here + ``` + +2. 创建一个 `ProcessGroupInitializer`。 你可以参考 `colossalai.context.dist_group_initializer` 中给出的例子,前六个参数是固定的。 +`ParallelContext` 将为你传入这些参数。如果你需要设置其他参数,可以像下面的例子中的 `arg1, arg2` 一样,在后面添加它。 +最后,通过添加装饰器 `@DIST_GROUP_INITIALIZER.register_module` 将你的初始化程序注册到注册表。 + ```python + # sample initializer class + @DIST_GROUP_INITIALIZER.register_module + class MyParallelInitializer(ProcessGroupInitializer): + + def __init__(self, + rank: int, + world_size: int, + config: Config, + data_parallel_size: int, + pipeline_parlalel_size: int, + tensor_parallel_size: int, + arg1, + arg2): + super().__init__(rank, world_size, config) + self.arg1 = arg1 + self.arg2 = arg2 + # ... your variable init + + def init_parallel_groups(self): + # initialize your process groups + pass + + ``` + 然后,你可以将你的新初始化器插入到 `colossalai.constants.INITIALIZER_MAPPING` 当前的模式与初始化映射中。你可以修改该文件或动态插入新的键值对。 + + ```python + colossalai.constants.INITIALIZER_MAPPING['new_mode'] = 'MyParallelInitializer' + ``` + +3. 在你的配置文件中设置你的初始化器。你可以传入你的自定义参数。这允许 + `ParallelContext` 创建你的初始化器并初始化你期望的进程组。 + + ```python + parallel = dict( + pipeline=dict(size=1), + tensor=dict(size=x, mode='new_mode') # this is where you enable your new parallel mode + ) + ``` + +## 梯度 Handler + +梯度 handler 是对参数的梯度执行 all-reduce 操作的对象。由于不同的 all-reduce 策略或许在不同的并行中被执行,用户可以继承 +`colossalai.engine.gradient_handler.BaseGradientHandler` 来实现其策略。目前,Colossal-AI 使用普通的数据并行梯度 handler 在数据并行的 rank 间 all-reduce 梯度。 +如果数据并行被检测到,梯度 handler 会被自动添加进 engine。 + +你可以添加你自己的梯度 handler,如下所示: + +```python +from colossalai.registry import GRADIENT_HANDLER +from colossalai.engine import BaseGradientHandler + +@GRADIENT_HANDLER.register_module +class YourGradientHandler(BaseGradientHandler): + + def handle_gradient(self): + do_something() + +``` + +之后,你可以在配置文件中指定你要使用的梯度 handler。 + +```python +gradient_handlers = [ + dict(type='YourGradientHandler'), +] +``` + +## Schedule + +Schedule 包含了如何执行前向和后向计算。目前, Colossal-AI 提供了流水和非流水的 schedule。 +如果你想修改前向和后向计算的执行方式,你可以继承 `colossalai.engine.schedule.BaseSchedule` 并实现 `forward_back_step` 函数。 diff --git a/docs/source/zh-Hans/advanced_tutorials/define_your_own_parallel_model.md b/docs/source/zh-Hans/advanced_tutorials/define_your_own_parallel_model.md new file mode 100644 index 0000000000000000000000000000000000000000..64e8d8bcd14a19c841ccb146257b2db018249f64 --- /dev/null +++ b/docs/source/zh-Hans/advanced_tutorials/define_your_own_parallel_model.md @@ -0,0 +1,31 @@ +# 定义你自己的并行模型 + +作者: Zhengda Bian, Yongbin Li + +> ⚠️ 我们正在编写此文档以使其更加详细。 我们将介绍不同并行的机制以及如何使用它们来编写模型。 + +假设您有一个具有数十亿参数的巨大 MLP 模型,其极大的隐藏层大小使其无法直接被单个 GPU 容纳。别担心,Colossal-AI 可以帮你解决这个问题。 +在 Colossal-AI 的帮助下,您可以用所熟悉的为单个 GPU 编写模型的方式编写大模型,而 Colossal-AI 会自动拆分您的模型权重,并将它们完美地分配到一组 GPU 中。我们给出一个简单的示例,展示如何在 Colossal-AI 中编写简单的 2D 并行模型。 + +## 写一个简单的2D并行模型 + +```python +from colossalai.nn import Linear2D +import torch.nn as nn + +class MLP_2D(nn.Module): + + def __init__(self): + super().__init__() + self.linear_1 = Linear2D(in_features=1024, out_features=16384) + self.linear_2 = Linear2D(in_features=16384, out_features=1024) + + def forward(self, x): + x = self.linear_1(x) + x = self.linear_2(x) + return x +``` + +## 使用预定义的模型 + +为了方便您的使用,我们在 Colossal-AI 的 Model Zoo 中提供一些流行的模型,如*BERT*, *ViT*, *MoE* 和 *GPT*,请自由地将它们定制为不同的尺寸,以满足您的特殊需求。 diff --git a/docs/source/zh-Hans/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md b/docs/source/zh-Hans/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md new file mode 100644 index 0000000000000000000000000000000000000000..456878caa14715ac07549b4d61736d4a84e64a2c --- /dev/null +++ b/docs/source/zh-Hans/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md @@ -0,0 +1,140 @@ +# 将 MoE 整合进你的模型 + +作者: Haichen Huang, Yongbin Li + +**前置教程** +- [ColossalAI-Examples WideNet](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet) + +**相关论文** +- [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) +- [Go Wider Instead of Deeper](https://arxiv.org/abs/2107.11817) + +(中文版教程将会在近期提供) + +## Introduction + +Since the advent of Switch Transformer, the AI community has found Mixture of Experts (MoE) a useful technique to enlarge the capacity of deep learning models. + +Colossal-AI provides an early access version of parallelism specifically designed for MoE models. +The most prominent advantage of MoE in Colossal-AI is convenience. +We aim to help our users to easily combine MoE with model parallelism and data parallelism. + +However, the current implementation has two main drawbacks now. +The first drawback is its poor efficiency in large batch size and long sequence length training. +The second drawback is incompatibility with tensor parallelism. +We are working on system optimization to overcome the training efficiency problem. +The compatibility problem with tensor parallelism requires more adaptation, and we will tackle this issue in the future. + +Here, we will introduce how to use MoE with model parallelism and data parallelism. + +## Table of Content +In this tutorial we will cover: +1. Set up MoE running environment +2. Create MoE layer +3. Train your model + +We provided the [example code](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet) for this tutorial in [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples). +This example uses [WideNet](https://arxiv.org/abs/2107.11817) as an example of MoE-based model. + + +## Set up MoE running environment +In your project folder, create a `config.py`. + +This file is to specify some features you may want to use to train your model. +In order to enable MoE, you need to add a dict called parallel and specify the value of key moe. +You can assign a value for the key size of moe, which represents the model parallel size of experts (i.e. the number of experts in one group to parallelize training). + +For example, if the size is 4, 4 processes will be assigned to 4 consecutive GPUs and these 4 processes form a moe model parallel group. +Each process on the 4 GPUs will only get a portion of experts. Increasing the model parallel size will reduce communication cost, but increase computation cost in each GPU and activation cost in memory. +The total data parallel size is auto-detected and set as the number of GPUs by default. + +```python +MOE_MODEL_PARALLEL_SIZE = ... +parallel = dict( + moe=dict(size=MOE_MODEL_PARALLEL_SIZE) +) +``` + +If `MOE_MODEL_PARALLEL_SIZE = E` and set the number of experts as `E` where `E` is a constant number, the process flow of forward pass of a transformer encoder in a model parallel group is shown below. + +
+ +
MoE Transformer, image source: GShard
+
+ +Since all experts are allocated to all GPUs in a model parallel group and a GPU only owns a portion of experts, +original data parallel groups are no longer correct for the parameters of experts during gradient handling in backward pass anymore. +So we create a new kind of parallel group called moe data parallel group. +The difference among different kinds of parallel group, when the configuration is set as `WORLD_SIZE=4`, +`MOE_MODEL_PARALLEL_SIZE=2`, is shown here. + +
+ +
MoE process group
+
+ + +As for gradient handling, we provide MoeGradientHandler to all-reduce every parameter of the model. +If you use `colossalai.initialize` function to create your training engine, the MoE gradient handler will be added to your engine automatically. +Otherwise, you should take care of gradient by yourself. +All parameters of MoE running environment are stored in colossalai.global_variables.moe_env. +You can access your configuration parameters to check whether your setup is correct. +```python +from colossalai.global_variables import moe_env +``` + +## Create MoE layer +You can create a MoE layer from `colossalai.nn.moe`. +But before doing that, you should set up random seeds for all processes like this. + +```python +from colossalai.context.random import moe_set_seed +from model_zoo.moe.models import Widenet + +moe_set_seed(42) +model = Widenet(num_experts=4, capacity_factor=1.2) +``` + +`moe_set_seed` will set different seed for different processes in a moe model parallel group. +This helps initialize parameters in experts. +Then create an instance of experts and an instance of router. +Here is the example in model zoo. + +```python +from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator + + +noisy_func = NormalNoiseGenerator(num_experts) +shared_router = Top2Router(capacity_factor, + noisy_func=noisy_func) +shared_experts = Experts(expert=VanillaFFN, + num_experts=num_experts, + **moe_mlp_args( + d_model=d_model, + d_ff=d_ff, + drop_rate=drop_rate + )) +ffn=MoeLayer(dim_model=d_model, num_experts=num_experts, + router=shared_router, experts=shared_experts) +``` + +Inside the initialization of Experts, the local expert number of each GPU will be calculated automatically. You just need to specify the class of each expert and its parameters used in its initialization. As for routers, we have provided top1 router and top2 router. You can find them in colossalai.nn.layer.moe. After creating the instance of experts and router, the only thing initialized in Moelayer is gate module. More definitions of each class can be found in our API document and code. + + +## Train Your Model +Do not to forget to use `colossalai.initialize` function in `colosalai` to add gradient handler for the engine. +We handle the back-propagation of MoE models for you. +In `colossalai.initialize`, we will automatically create a `MoeGradientHandler` object to process gradients. +You can find more information about the handler `MoeGradientHandler` in colossal directory. + +The loss criterion should be wrapped by `Moeloss` to add auxiliary loss of MoE. Example is like this. +```python +criterion = MoeLoss( + aux_weight=0.01, + loss_fn=nn.CrossEntropyLoss, + label_smoothing=0.1 +) +``` + +Finally, just use trainer or engine in `colossalai` to do your training. +Otherwise, you should take care of gradient by yourself. diff --git a/docs/source/zh-Hans/advanced_tutorials/meet_gemini.md b/docs/source/zh-Hans/advanced_tutorials/meet_gemini.md new file mode 100644 index 0000000000000000000000000000000000000000..2bf0a9c98c3f9dc92423f1320dd3da02a61a6fc5 --- /dev/null +++ b/docs/source/zh-Hans/advanced_tutorials/meet_gemini.md @@ -0,0 +1,96 @@ +# 认识Gemini:ColossalAI的异构内存空间管理器 + +作者: [Jiarui Fang](https://github.com/feifeibear) + +## 简介 + +在GPU数量不足情况下,想要增加模型规模,异构训练是最有效的手段。它通过在 CPU 和 GPU 中容纳模型数据,并仅在必要时将数据移动到当前设备,可以同时利用 GPU 内存、CPU 内存(由 CPU DRAM 或 NVMe SSD内存组成)来突破单GPU内存墙的限制。并行,在大规模训练下,其他方案如数据并行、模型并行、流水线并行都可以在异构训练基础上进一步扩展GPU规模。这篇文章描述ColossalAI的异构内存空间管理模块Gemini的设计细节,它的思想来源于[PatrickStar](https://arxiv.org/abs/2108.05818),ColossalAI根据自身情况进行了重新实现。 + +## 用法 + +目前Gemini支持和ZeRO并行方式兼容,它的使用方法很简单,在训练策略的配置文件里设置zero的model_config属性tensor_placement_policy='auto' + +``` +zero = dict( + model_config=dict( + reduce_scatter_bucket_size_mb=25, + fp32_reduce_scatter=False, + gradient_predivide_factor=1.0, + tensor_placement_policy="auto", + shard_strategy=TensorShardStrategy(), + ... + ), + optimizer_config=dict( + ... + ) +) +``` + +注意,Gemini和并行策略,如Tensor Parallelism,Data Parallelism,Pipeline Parallelism,ZeRO是解耦合的。对TP,PP的支持还在开发中。 + +## 术语 + +**算子**(**OP**erator):一个神经网络层的计算操作,比如Linear,LayerNorm等。算子可以是正向传播的计算,也可以是反向传播的计算。 + +神经网络在训练期间必须管理的两种类型的训练数据。 + +**模型数据(model data)**: 由参数、梯度和优化器状态组成,其规模与模型结构定义相关 + +**非模型数据(non-model data)**: 主要由算子生成的中间张量和算子的临时变量组成。非模型数据根据训练任务的配置动态变化,例如批量大小。模型数据和非模型数据相互竞争 GPU 内存。 + +## 设计 + +目前的一些解决方案,DeepSpeed采用的[Zero-offload](https://arxiv.org/abs/2101.06840)在CPU和GPU内存之间静态划分模型数据,并且它们的内存布局对于不同的训练配置是恒定的。如下图左边所示,当 GPU 内存不足以满足其相应的模型数据要求时,即使当时CPU上仍有可用内存,系统也会崩溃。而ColossalAI可以通过将一部分模型数据换出到CPU上来完成训练。 + +
+ +
比较Zero-Offload和Gemini的内存管理方案
+
+ + +ColossalAI设计了Gemini,就像双子星一样,它管理CPU和GPU二者内存空间。它可以让张量在训练过程中动态分布在CPU-GPU的存储空间内,从而让模型训练突破GPU的内存墙。内存管理器由两部分组成,分别是MemStatsCollector(MSC)和StatefuleTensorMgr(STM)。 + + +我们利用了深度学习网络训练过程的迭代特性。我们将迭代分为warmup和non-warmup两个阶段,开始时的一个或若干迭代步属于预热阶段,其余的迭代步属于正式阶段。在warmup阶段我们为MSC收集信息,而在non-warmup阶段STM入去MSC收集的信息来移动tensor,以达到最小化CPU-GPU数据移动volume的目的。 + +
+ +
Gemini在不同训练阶段的运行流程
+
+ + +### StatefulTensorMgr + +STM管理所有model data tensor的信息。在模型的构造过程中,ColossalAI把所有model data张量注册给STM。内存管理器给每个张量标记一个状态信息。状态集合包括HOLD,COMPUTE,FREE三种状态。STM的功能如下: + +**查询内存使用:**通过遍历所有tensor的在异构空间的位置,获取模型数据对CPU和GPU的内存占用。 + +**转换张量状态:**它在每个模型数据张量参与算子计算之前,将张量标记为COMPUTE状态,在计算之后标记为HOLD状态。如果张量不再使用则标记的FREE状态。 + +**调整张量位置:**张量管理器保证COMPUTE状态的张量被放置在计算设备上,如果计算设备的存储空间不足,则需要移动出一些HOLD状态的张量到其他设备上存储。Tensor eviction strategy需要MSC的信息,我们将在后面介绍。 + + +### MemStatsCollector +在预热阶段,内存信息统计器监测CPU和GPU中模型数据和非模型数据的内存使用情况,供正式训练阶段参考。我们通过查询STM可以获得模型数据在某个时刻的内存使用。但是非模型的内存使用却难以获取。因为非模型数据的生存周期并不归用户管理,现有的深度学习框架没有暴露非模型数据的追踪接口给用户。MSC通过采样方式在预热阶段获得非模型对CPU和GPU内存的使用情况。具体方法如下: + +我们在算子的开始和结束计算时,触发内存采样操作,我们称这个时间点为**采样时刻(sampling moment)**,两个采样时刻之间的时间我们称为**period**。计算过程是一个黑盒,由于可能分配临时buffer,内存使用情况很复杂。但是,我们可以较准确的获取period的系统最大内存使用。非模型数据的使用可以通过两个统计时刻之间系统最大内存使用-模型内存使用获得。 + +我们如何设计采样时刻呢。我们选择preOp的model data layout adjust之前。如下图所示。我们采样获得上一个period的system memory used,和下一个period的model data memoy used。并行策略会给MSC的工作造成障碍。如图所示,比如对于ZeRO或者Tensor Parallel,由于Op计算前需要gather模型数据,会带来额外的内存需求。因此,我们要求在模型数据变化前进行采样系统内存,这样在一个period内,MSC会把preOp的模型变化内存捕捉。比如在period 2-3内,我们考虑的tensor gather和shard带来的内存变化。 +尽管可以将采样时刻放在其他位置,比如排除gather buffer的变动新信息,但是会给造成麻烦。不同并行方式Op的实现有差异,比如对于Linear Op,Tensor Parallel中gather buffer的分配在Op中。而对于ZeRO,gather buffer的分配是在PreOp中。将放在PreOp开始时采样有利于将两种情况统一。 + + +尽管可以将采样时刻放在其他位置,比如排除gather buffer的变动新信息,但是会给造成麻烦。不同并行方式Op的实现有差异,比如对于Linear Op,Tensor Parallel中gather buffer的分配在Op中。而对于ZeRO,gather buffer的分配是在PreOp中。将放在PreOp开始时采样有利于将两种情况统一。 + +
+ +
Sampling based MemStatsCollector
+
+ +### Tensor Eviction Strategy + +MSC的重要职责是在调整tensor layout位置,比如在上图S2时刻,我们减少设备上model data数据,Period 2-3计算的峰值内存得到满足。 + +在warmup阶段,由于还没执行完毕一个完整的迭代,我们对内存的真实使用情况尚一无所知。我们此时限制模型数据的内存使用上限,比如只使用30%的GPU内存。这样保证我们可以顺利完成预热状态。 + +在non-warmup阶段,我们需要利用预热阶段采集的非模型数据内存信息,预留出下一个Period在计算设备上需要的峰值内存,这需要我们移动出一些模型张量。 +为了避免频繁在CPU-GPU换入换出相同的tensor,引起类似[cache thrashing](https://en.wikipedia.org/wiki/Thrashing_(computer_science))的现象。我们利用DNN训练迭代特性,设计了OPT cache换出策略。具体来说,在warmup阶段,我们记录每个tensor被计算设备需要的采样时刻。如果我们需要驱逐一些HOLD tensor,那么我们选择在本设备上最晚被需要的tensor作为受害者。 diff --git a/docs/source/zh-Hans/advanced_tutorials/opt_service.md b/docs/source/zh-Hans/advanced_tutorials/opt_service.md new file mode 100644 index 0000000000000000000000000000000000000000..a213584fd41d52b6492e309b7a4cef9bd500065c --- /dev/null +++ b/docs/source/zh-Hans/advanced_tutorials/opt_service.md @@ -0,0 +1,79 @@ +# Colossal-AI使用指南:5分钟搭建在线OPT服务 + +## 介绍 + +本指导手册将说明如何利用[Colossal-AI](https://github.com/hpcaitech/ColossalAI)搭建您自己的OPT服务。 + +## Colossal-AI 推理概述 +Colossal-AI 提供了一个推理子系统 [Energon-AI](https://github.com/hpcaitech/EnergonAI), 这是一个基于Colossal-AI的服务系统,拥有以下特性: + +- **大模型并行:** 在Colossal-AI的张量并行和流水线并行策略的帮助下,Colossal-AI的推理可实现大模型的高效并行推理。 +- **预构建大模型:** Colossal-AI提供热门模型的预构建部署,例如OPT。其支持用于生成任务和加载检查点的缓存技术。 +- **引擎封装:** Colossal-AI中有一个抽象层被称作引擎。其将单实例多设备(SIMD) 执行与远程过程调用封装在一起。 +- **在线服务系统:** 基于FastAPI,用户可以快速启动分布式推理的网络服务。 在线服务对生成任务进行了特殊优化。它采用left padding和bucket batching两种技术来提高效率。 + +## 基本用法 + +1. 下载OPT模型 + +想要快速发布分布式推理服务,您从[此处](https://huggingface.co/patrickvonplaten/opt_metaseq_125m/blob/main/model/restored.pt)下载OPT-125M。有关加载其他体量模型的详细方法,您可访问[此处](https://github.com/hpcaitech/EnergonAI/tree/main/examples/opt/script)。 + +2. 准备提前构建的服务镜像 + +从dockerhub拉取一个已经安装Colossal-AI推理的docker镜像。 + +```bash +docker pull hpcaitech/energon-ai:latest +``` + +3. 发布HTTP服务 + +若想发布服务,我们需要准备python脚本来描述模型的类型和相关的部署,以及HTTP服务的设置。 我们为您提供了一组[示例](https://github.com/hpcaitech/EnergonAI/tree/main/examples])。 我们将在本指导手册中使用[OPT 示例](https://github.com/hpcaitech/EnergonAI/tree/main/examples/opt)。 +服务的入口是一个bash脚本 server.sh。 +本服务的配置文件参考 opt_config.py,该文件定义了模型的类型、 检查点文件路径、并行策略和http设置。您能按照您的需求来修改这些设置。 +例如,将模型的大小设置为opt_125M,将正确的检查点路径按照如下设置: + +```bash +model_class = opt_125M +checkpoint = 'your_file_path' +``` + +将张量并行度设置为您的gpu数量。 + +```bash +tp_init_size = #gpu +``` + +现在,我们就能利用docker发布一个服务。您能在`/model_checkpoint` 和 `/config`路径下找到检查点文件和配置文件。 + + +```bash +export CHECKPOINT_DIR="your_opt_checkpoint_path" +# the ${CONFIG_DIR} must contain a server.sh file as the entry of service +export CONFIG_DIR="config_file_path" + +docker run --gpus all --rm -it -p 8020:8020 -v ${CHECKPOINT_DIR}:/model_checkpoint -v ${CONFIG_DIR}:/config --ipc=host energonai:lastest +``` + +接下来,您就可以在您的浏览器中打开 `https://[IP-ADDRESS]:8020/docs#` 进行测试。 + +## 高级特性用法 + +1. 批处理优化 + +若想使用我们的高级批处理技术来批量收集多个查询,您可以将executor_max_batch_size设置为最大批处理大小。 请注意,只有具有相同 top_k、top_p 和温度的解码任务才能一起批处理。 + +``` +executor_max_batch_size = 16 +``` + +所有的查询将进入FIFO队列。解码步数小于或等于队列头部解码步数的所有连续查询可以一起批处理。 应用左填充以确保正确性。 executor_max_batch_size 不应该过大,从而确保批处理不会增加延迟。 以opt-30b为例, `executor_max_batch_size=16` 合适,但对于opt-175b而言, `executor_max_batch_size=4` 更合适。 + +2. 缓存优化 + +对于每一个独立的服务过程,您能将最近的多个查询结果缓存在一起。在config.py中设置 cache_size 和 cache_list_size。缓存的大小应为缓存的查询数目。cache_list_size 应为每次查询存储的结果数。一个随机缓存的结果将会被返回。当缓存已满,LRU策略被用于清理缓存过的查询。cache_size=0意味着不缓存。 + +``` +cache_size = 50 +cache_list_size = 2 +``` diff --git a/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md b/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md new file mode 100644 index 0000000000000000000000000000000000000000..f3c6247c38e452f98ba198195d8cccde390922f4 --- /dev/null +++ b/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md @@ -0,0 +1,176 @@ +# 使用ColoTensor让串行程序像Megatron-LM一样并行 + +Author: [Haichen Huang](https://github.com/1SAA) and [Jiarui Fang](https://github.com/feifeibear) + +**Prerequisite:** +- [ColoTensor Concepts](../basics/colotensor_concept.md) + +## 介绍 + +在新版本中,我们引入了ColoTensor。ColoTensor为用户使用并行训练提供了极大的便利,使得用户可以在原本的串行代码上,通过较小的修改将训练改为并行。在本教程中,我们将说明如何修改训练模型以自动使代码采取像 Megatron-LM 一样的方式并行训练。我们以 HuggingFace 提供的 GPT-2 模型为例,并提供一种方式让你可以在单个GPU上预训练GPT-2模型。 + +Megatron-LM 提供了一个具有影响力的并行化范式,这个范式主要应用于Transformer大模型的训练。然而,为了大规模训练 Transformer 语言大模型,用户必须使用Megatron-LM提供的特殊模块来构建他们的模型。这给用户带来了一些困难的工作,例如从预先训练的模型中加载权重,或是构建自己的并行训练模型。为了减轻用户的麻烦,我们提供 ColoTensor 类,以完成自动启用张量模型并行。 + +## 定义模型和损失函数 + +首先,我们直接调用 HuggingFace 库中的 GPTModel 和 GPTLoss。 + +```python +import torch +import torch.nn as nn +from transformers import GPT2Config, GPT2LMHeadModel + +class GPTLMModel(nn.Module): + def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers, + n_head=num_attention_heads, n_positions=max_seq_len, n_ctx=max_seq_len, vocab_size=vocab_size)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +class GPTLMLoss(nn.Module): + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) +``` + +## 对GPT-2的简短回顾 + +现在,我们回顾一下 GPT-2 模型的结构。每个 GPT-2 模型都可以表示为一个 DAG。如下图所示,每个圆圈代表一个算子,每个方块代表一个权重。每个箭头表示输入数据的流向,而箭头旁边的符号表示输入数据的形状。 + +然后,让我们深入了解一下这个 GPT-2 模型。它由三部分组成,分别是**嵌入模块**、**转换器层**和**分类头**。 + +嵌入模块包含两个权重,符号嵌入权重和位置嵌入权重。在嵌入模块的前向操作之后,原始输入数据的所有序列中的每个单词都会被嵌入到隐藏状态。 + +
+ +
嵌入模块
+
+ +每个转换器层包含两个块。自注意操作在第一个块中调用,同时一个双层感知器位于第二个块中。 + +
+ +
转换器层
+
+ +最后,分类头只是一个不加偏差的线性模块,里面只有一个线性权重。 + +## 应用ColoTensor + +两个步骤使您的串行代码采取 Megatron-LM 张量并行风格。 +1. 在ColoInitContext的上下文中初始化模型。 +2. 为每个参数设置 ColoTensorSpec。 + +### 使用 ColoInitContext 初始化 + +我们应该在 ColoInitContext 中构建模型。在该种上下文中,任何初始化的参数都将转换为 ColoParameter 并自动移动到相应的设备上。 + +```python +from colossalai.utils.model.colo_init_context import ColoInitContext + +with ColoInitContext(device=torch.device('cpu')): + model = GPTLMModel() +``` + +### 为每个参数设置 ColoTensorSpec + +模型创建完成后,我们通过ProcessGroup建立分布式环境。这里,我们将张量并行度指定为所有GPU的数量,即数据并行度为一。 + +```python +import torch.distributed as dist +from colossalai.tensor import ProcessGroup + +pg = ProcessGroup(tp_degree=dist.get_world_size()) +``` + +现在,我们需要一些辅助函数为下一步做准备。我们定义了两个函数来切分参数。Megatron-LM张量并行需要沿参数的第一维或最后一维切分参数张量。 + +```python +from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern, ColoParameter, ProcessGroup + +def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + if param.process_group.tp_world_size() == 1: + param.set_process_group(pg) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(-1, param, pg) +``` + +然后我们使模型采用张量并行。根据 Megatron 中使用的张量并行,应该沿着张量的最后一个维度进行切片,包括符号嵌入的权重,位置嵌入的权重,自注意力块中的所有线性权重和偏差,以及每个双层感知器中的第一个线性权重和偏差。且需要沿第一个维度切分双层感知器中的第二个线性权重。 + +```python +for mn, module in model.named_modules(): + for pn, param in module.named_parameters(recurse=False): + # set process group for all parameters + param.set_process_group(pg) + + if 'mlp.c_fc' in mn: + if 'weight' in pn or 'bias' in pn: + split_param_col_tp1d(param, pg) # colmn slice + # keep the shape of the output from c_fc + param.compute_spec.set_output_replicate(False) + elif 'mlp.c_proj' in mn: + if 'weight' in pn: + split_param_row_tp1d(param, pg) # row slice + elif 'wte' in mn or 'wpe' in mn: + split_param_col_tp1d(param, pg) # colmn slice + elif 'c_attn' in mn or 'c_proj' in mn: + split_param_col_tp1d(param, pg) # colmn slice +``` + +修改后的模型如下图所示。 + +嵌入模块: + +
+ +
修改后的嵌入模块
+
+ +转换器层: + +
+ +
修改后的转换器层
+
+ +一旦用户指定了每个参数的在并行中的分布模式,ColoTensor 就能够推断出所有算子的计算模式,包括矩阵乘法、线性函数、torch.nn.functional 中的其他逐元素函数,以及其他的一些常用函数。这样,用户可以像往常一样训练他们的模型。 + +在我们最新示例中还定义了一个Gemini + ZeRO DDP 的模型从而减小开销,提升效率。这一部分的详细内容可以参考[ZeRO](../features/zero_with_chunk.md),你可以将这两部分内容结合起来看从而理解我们整个训练流程: + +```python +def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): + from colossalai.nn.parallel import GeminiDDP + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=placememt_policy, + pin_memory=True, + search_range_mb=32) + return model +``` + +## 在单个GPU上预训练GPT-2 + +我们做的上述优化让我们可以在单GPU上训练GPT-2模型,只需要将`run.sh`中设置参数`GPUNUM`=1,再运行文件时就可以在单个GPU上完成模型的训练。 + +GPT-2 示例在[Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). 获得。 diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md new file mode 100644 index 0000000000000000000000000000000000000000..6c6dcf6e850db886cec080e131065af306187c13 --- /dev/null +++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -0,0 +1,275 @@ +# 使用混合并行训练 GPT + +作者: Hongxin Liu, Yongbin Li + +**示例代码** +- [ColossalAI-Examples GPT2](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt_2) +- [ColossalAI-Examples GPT3](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt_3) + +**相关论文** +- [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883) +- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) + +## 引言 + +在上一篇教程中,我们介绍了如何用流水并行训练 ViT。在本教程中,你将学习一个更复杂的场景--用混合并行方式训练GPT。在这种情况下,由于GPT-3过大,即使CPU内存也无法容纳它。因此,你必须自己分割模型。 + +## 目录 + +在本教程中,我们将介绍: + +1. 基于 colossalai/model_zoo 定义 GPT 模型 +2. 处理数据集 +3. 使用混合并行训练 GPT + +## 导入依赖库 + +```python +import json +import os +from typing import Callable + +import colossalai +import colossalai.utils as utils +import model_zoo.gpt.gpt as col_gpt +import torch +import torch.nn as nn +from colossalai import nn as col_nn +from colossalai.amp import AMP_TYPE +from colossalai.builder.pipeline import partition_uniform +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.engine.schedule import (InterleavedPipelineSchedule, + PipelineSchedule) +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper +from colossalai.trainer import Trainer, hooks +from colossalai.utils.timer import MultiTimer +from model_zoo.gpt import GPTLMLoss +from torch.nn import functional as F +from torch.utils.data import Dataset +from transformers import GPT2Tokenizer +``` + + + +## 定义 GPT 模型 + +在前面的教程中,我们介绍了3种建立流水并行模型的方法,但对于像 GPT-3 这样的巨大模型,你甚至不能在 CPU 中建立模型。在这种情况下,你必须自己分割模型。 + +GPT 数据加载器返回 `input_ids` 和 `attention_mask`, 因此我们在 `forward()` 中使用两个关键字参数来获得它们。请注意,对于除第一阶段以外的其他阶段, `forward()` 的第一个位置参数是上一阶段的输出张量。所以 `hidden_states` 来自前一阶段,并且对于第一阶段来说,它是 `None`。 + +对于 GPT, *word embedding layer* 与 *output head* 共享权重。我们提供 `PipelineSharedModuleWrapper` 在流水阶段间共享参数。它需要一个 `int` 型的 `list` 作为参数, 这意味着 rank 们共享这些参数。你可以使用 `register_module()` +或 `register_parameter()` 来注册一个模块或一个参数作为共享模块或参数。如果你有多组共享模块/参数,你应该有多个 `PipelineSharedModuleWrapper` 实例。 如果参数在**一个**阶段内共享, 你不应该使用 +`PipelineSharedModuleWrapper`, 而只是使用同一个模块/参数实例。在这个例子中,*word embedding layer* 在第一阶段, 而 *output head* 在最后一个阶段。因此,他们在 rank `[0, pipeline_size - 1]` 之间共享参数。 + +对于第一阶段,它维护 embedding layer 和一些 transformer blocks。对于最后一个阶段,它维护一些 transformer blocks 和 output head layer。对于其他阶段,他们只维护一些 transformer blocks。 +`partition_uniform(num_layers, pipeline_size, num_chunks)` 返回所有 rank 的 parts, part 是一个 `(start, end)` (不包括end) 的 `tuple`。`start == 0` 表示这是第一阶段, 而 `end == num_layers` 表示这是最后一个阶段。 + +```python +class PipelineGPTHybrid(nn.Module): + def __init__(self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + embed_drop_rate: float = 0., + act_func: Callable = F.gelu, + mlp_ratio: int = 4, + attn_drop_rate: float = 0., + drop_rate: float = 0., + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + first: bool = False, + last: bool = False): + super().__init__() + self.embedding = None + self.norm = None + self.head = None + if first: + self.embedding = col_gpt.GPTEmbedding( + hidden_size, vocab_size, max_position_embeddings, dropout=embed_drop_rate, dtype=dtype) + self.blocks = nn.ModuleList([ + col_gpt.GPTBlock(hidden_size, num_attention_heads, mlp_ratio=mlp_ratio, attention_dropout=attn_drop_rate, + dropout=drop_rate, dtype=dtype, checkpoint=checkpoint, activation=act_func) + for _ in range(num_layers) + ]) + if last: + self.norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.head = col_gpt.GPTLMHead(vocab_size=vocab_size, + dim=hidden_size, + dtype=dtype, + bias=False) + + def forward(self, hidden_states=None, input_ids=None, attention_mask=None): + if self.embedding is not None: + hidden_states = self.embedding(input_ids=input_ids) + batch_size = hidden_states.shape[0] + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + for block in self.blocks: + hidden_states, attention_mask = block(hidden_states, attention_mask) + if self.norm is not None: + hidden_states = self.head(self.norm(hidden_states)) + return hidden_states + + +def build_gpt_pipeline(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): + logger = get_dist_logger() + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + rank = gpc.get_global_rank() + wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1]) + parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] + models = [] + for start, end in parts: + kwargs['num_layers'] = end - start + kwargs['first'] = start == 0 + kwargs['last'] = end == num_layers + logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') + chunk = PipelineGPTHybrid(**kwargs).to(device) + if start == 0: + wrapper.register_module(chunk.embedding.word_embeddings) + elif end == num_layers: + wrapper.register_module(chunk.head) + models.append(chunk) + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + return model + + +def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float): + cfg = dict(hidden_size=1600, num_attention_heads=32, checkpoint=checkpoint, dtype=dtype) + return build_gpt_pipeline(48, num_chunks, **cfg) + + +def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float): + cfg = dict(hidden_size=12288, num_attention_heads=96, + checkpoint=checkpoint, max_position_embeddings=2048, dtype=dtype) + return build_gpt_pipeline(96, num_chunks, **cfg) +``` + +## 处理数据集 + +我们在这里提供了一个小型 GPT web-text 数据集。 原始格式是 loose JSON, 我们将保存处理后的数据集。 + +```python +class WebtextDataset(Dataset): + def __init__(self, path, seq_len=1024) -> None: + super().__init__() + root = os.path.dirname(path) + encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt') + if os.path.isfile(encoded_data_cache_path): + seq_len_, data, attention_mask = torch.load( + encoded_data_cache_path) + if seq_len_ == seq_len: + self.data = data + self.attention_mask = attention_mask + return + raw_data = [] + with open(path) as f: + for line in f.readlines(): + raw_data.append(json.loads(line)['text']) + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.unk_token + encoded_data = tokenizer( + raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt') + self.data = encoded_data['input_ids'] + self.attention_mask = encoded_data['attention_mask'] + torch.save((seq_len, self.data, self.attention_mask), + encoded_data_cache_path) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return { + 'input_ids': self.data[index], + 'attention_mask': self.attention_mask[index] + }, self.data[index] +``` + +## 使用混合并行训练 GPT + +在上一个教程中,我们解释了一些流水并行的参数含义。在本例中,我们可以确定在流水阶段之间交换的每个输出张量的形状。对于 GPT,该形状为 +`(MICRO BATCH SIZE, SEQUENCE LEN, HIDDEN SIZE)`。通过设置该参数,我们可以避免交换每个阶段的张量形状。当你不确定张量的形状时,你可以把它保留为 +`None`, 形状会被自动推测。请确保你的模型的 `dtype` 是正确的:当你使用 `fp16`,模型的 `dtype` 必须是 `torch.half`;否则,`dtype` 必须是 `torch.float`。对于流水并行,仅支持 `AMP_TYPE.NAIVE`。 + +你可以通过在 `CONFIG` 里使用 `parallel` 来轻松使用张量并行。数据并行的大小是根据 GPU 的数量自动设置的。 + +```python +NUM_EPOCHS = 60 +SEQ_LEN = 1024 +BATCH_SIZE = 192 +NUM_CHUNKS = None +TENSOR_SHAPE = (1, 1024, 1600) +# only pipeline parallel +# CONFIG = dict(NUM_MICRO_BATCHES = 192, parallel=dict(pipeline=2), fp16=dict(mode=AMP_TYPE.NAIVE)) +# pipeline + 1D model parallel +CONFIG = dict(NUM_MICRO_BATCHES = 192, parallel=dict(pipeline=2, tensor=dict(mode='1d', size=2)), fp16=dict(mode=AMP_TYPE.NAIVE)) + + +def train(): + disable_existing_loggers() + parser = colossalai.get_default_parser() + args = parser.parse_args() + colossalai.launch_from_torch(config=CONFIG, backend=args.backend) + logger = get_dist_logger() + + train_ds = WebtextDataset(os.environ['DATA'], seq_len=SEQ_LEN) + train_dataloader = utils.get_dataloader(train_ds, + seed=42, + batch_size=BATCH_SIZE, + pin_memory=True, + shuffle=True, + drop_last=True) + + use_interleaved = NUM_CHUNKS is not None + num_chunks = 1 if not use_interleaved else NUM_CHUNKS + model = GPT2_exlarge_pipeline_hybrid(num_chunks=num_chunks, checkpoint=True, dtype=torch.half) + # model = GPT3_pipeline_hybrid(num_chunks=num_chunks, checkpoint=True, dtype=torch.half) + if use_interleaved and not isinstance(model, nn.ModuleList): + model = nn.ModuleList([model]) + + criterion = GPTLMLoss() + + optimizer = torch.optim.Adam(model.parameters(), lr=0.00015, weight_decay=1e-2,) + + engine, train_dataloader, _, _ = colossalai.initialize(model, + optimizer, + criterion, + train_dataloader=train_dataloader) + global_batch_size = BATCH_SIZE * \ + gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1) + logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0]) + + timer = MultiTimer() + + trainer = Trainer( + engine=engine, + logger=logger, + timer=timer + ) + + hook_list = [ + hooks.LossHook(), + hooks.LogMetricByEpochHook(logger), + hooks.ThroughputHook(), + hooks.LogMetricByStepHook(), + ] + + trainer.fit( + train_dataloader=train_dataloader, + epochs=NUM_EPOCHS, + test_interval=1, + hooks=hook_list, + display_progress=True, + return_output_label=False, + ) +``` diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md new file mode 100644 index 0000000000000000000000000000000000000000..495c7fa36cc1be2089f9736c700f950b0d3a5a33 --- /dev/null +++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md @@ -0,0 +1,246 @@ +# 使用流水并行训练 ViT + +作者: Hongxin Liu, Yongbin Li + +**示例代码** +- [ColossalAI-Examples Pipeline Parallel ViT](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/vision_transformer/pipeline_parallel) + +**相关论文** +- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) + +## 引言 + +在本教程中,你将学习如何使用流水并行从头开始训练用于图像分类的 Vision Transformer (ViT)。流水并行是一种模型并行,主要针对 GPU 内存不能满足模型容量的情况。 +通过使用流水并行,我们将原始模型分割成多个阶段,每个阶段保留原始模型的一部分。我们假设你的 GPU 内存不能容纳 ViT/L-16,而你的内存可以容纳这个模型。 + +## 目录 + +在本教程中,我们将介绍: + +1. 基于 [TIMM](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) 定义 ViT 模型 +2. 处理数据集 +3. 使用流水并行训练 ViT + +## 导入依赖库 + +```python +import os +from collections import OrderedDict +from functools import partial + +import colossalai +import colossalai.nn as col_nn +import torch +import torch.nn as nn +from colossalai.builder import build_pipeline_model +from colossalai.engine.schedule import (InterleavedPipelineSchedule, + PipelineSchedule) +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.trainer import Trainer, hooks +from colossalai.utils import MultiTimer, get_dataloader +from timm.models import vision_transformer as vit +from torchvision import transforms +from torchvision.datasets import CIFAR10 +``` + + +## 定义 Vision Transformer 模型 + +总的来说, 我们提供3种方法来建立一个流水并行的模型: + +1. `colossalai.builder.build_pipeline_model_from_cfg` +2. `colossalai.builder.build_pipeline_model` +3. 自己按阶段拆分模型 + +当你的内存能够容纳模型时,你可以使用前两种方法来建立你的模型,否则你必须自己分割模型。前两种方法首先在 CPU 上建立整个模型,然后分割模型,最后你可以直接把模型的相应部分移到 GPU 上。 + +`colossalai.builder.build_pipeline_model_from_cfg()` 接收一个模型的配置文件,它可以均匀地(按层)或平衡地(按参数大小)分割模型。 + +如果你熟悉 `PyTorch`, 你可以使用 `colossalai.builder.build_pipeline_model()` 它接收一个 `torch.nn.Sequential` 模型并按层均匀分割。 + +在本教程中,我们将修改 [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential`,然后使用 `colossalai.builder.build_pipeline_model()` 来建立流水线模型。 + +当数据是 **一个** `Tensor`, 你可以使用你的模型 `forward()` 中的位置参数来获得数据张量。对于流水线的第一阶段,`forward()` 的第一个位置参数是从数据加载器加载的数据张量。对于其他阶段,`forward()` 的第一个位置参数是上一阶段的输出张量。注意,如果该阶段不是最后一个阶段,则 `forward()` 的返回必须是一个 `Tensor`。 + +当数据是一个 `Tensor` 的 `dict`, 你可以使用你模型 `forward()` 的命名关键字参数来获得数据的 `dict`。 + +```python +class ViTEmbedding(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, embed_layer=vit.PatchEmbed, drop_rate=0., distilled=False): + super().__init__() + self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 2 if distilled else 1 + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + self.init_weights() + + def forward(self, x): + x = self.patch_embed(x) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.pos_drop(x + self.pos_embed) + return x + + def init_weights(self): + vit.trunc_normal_(self.pos_embed, std=.02) + if self.dist_token is not None: + vit.trunc_normal_(self.dist_token, std=.02) + vit.trunc_normal_(self.cls_token, std=.02) + self.apply(vit._init_vit_weights) + + +class ViTHead(nn.Module): + def __init__(self, embed_dim=768, num_classes=1000, norm_layer=None, distilled=False, representation_size=None): + super().__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.norm = norm_layer(embed_dim) + self.num_classes = num_classes + self.distilled = distilled + self.num_features = embed_dim + # Representation layer + if representation_size and not distilled: + self.num_features = representation_size + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + # Classifier head(s) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.init_weights() + + def forward(self, x): + x = self.norm(x) + if self.distilled: + x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1]) + if self.training and not torch.jit.is_scripting(): + # during inference, return the average of both classifier predictions + return x, x_dist + else: + return (x + x_dist) / 2 + else: + x = self.pre_logits(x[:, 0]) + x = self.head(x) + return x + + def init_weights(self): + self.apply(vit._init_vit_weights) + + +def sequential_vit(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=vit.PatchEmbed, norm_layer=None, + act_layer=None): + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + embedding = ViTEmbedding(img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, embed_layer=embed_layer, drop_rate=drop_rate, distilled=distilled) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + blocks = [vit.Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)] + for block in blocks: + block.apply(vit._init_vit_weights) + head = ViTHead(embed_dim=embed_dim, num_classes=num_classes, norm_layer=norm_layer, + distilled=distilled, representation_size=representation_size) + return nn.Sequential(embedding, *blocks, head) + + +def vit_large_patch16_224(**kwargs): + model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) + return sequential_vit(**model_kwargs) +``` + +## 处理数据集 + +一般来说, 我们在大型数据集如 ImageNet 上训练 ViT。为了简单期间,我们在这里只使用 CIFAR-10, 因为本教程只是用于流水并行训练。 + +```python +def build_cifar(batch_size): + transform_train = transforms.Compose([ + transforms.RandomCrop(224, pad_if_needed=True), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + transform_test = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train) + test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test) + train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True) + test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True) + return train_dataloader, test_dataloader +``` + +## 使用流水并行训练 ViT + +你可以在配置文件中设置流水并行的大小。`NUM_CHUNKS` 在使用交错流水线时很有用 (更多细节见 [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) )。 +原始 batch 将会被分割为 `num_microbatches`, 每个阶段每次将加载一个 micro batch。如果你确定性地知道每个阶段输出张量的形状,你可以在配置文件中设置 `tensor_shape` 来减少通信。 +我们的仓库会自动为用户生成合适的schedule来支持流水并行训练。如果你不需要模型的输出和标签,你可以在调用 `trainer.fit()` 时,将 `return_output_label` 设置为 `False`,这样能进一步减少 GPU 显存使用。 + +你应当使用 `export DATA=/path/to/cifar`。 + +```python +BATCH_SIZE = 16 +NUM_EPOCHS = 60 +NUM_CHUNKS = 1 +CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2)) + + +def train(): + disable_existing_loggers() + parser = colossalai.get_default_parser() + args = parser.parse_args() + colossalai.launch_from_torch(backend=args.backend, config=CONFIG) + logger = get_dist_logger() + + # build model + model = vit_large_patch16_224() + model = build_pipeline_model(model, num_chunks=NUM_CHUNKS, verbose=True) + + # build criterion + criterion = nn.CrossEntropyLoss() + + # optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0) + + # build dataloader + train_dataloader, test_dataloader = build_cifar(BATCH_SIZE) + + engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, optimizer, criterion, + train_dataloader, test_dataloader) + timer = MultiTimer() + + trainer = Trainer(engine=engine, timer=timer, logger=logger) + + hook_list = [ + hooks.LossHook(), + hooks.AccuracyHook(col_nn.metric.Accuracy()), + hooks.LogMetricByEpochHook(logger), + ] + + trainer.fit(train_dataloader=train_dataloader, + epochs=NUM_EPOCHS, + test_dataloader=test_dataloader, + test_interval=1, + hooks=hook_list, + display_progress=True) +``` diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md new file mode 100644 index 0000000000000000000000000000000000000000..6dc5eccf44218ced733de4bd5bb22fecc302c61e --- /dev/null +++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md @@ -0,0 +1,591 @@ +# 使用 Colossal-AI (从数据并行到异构并行)加速 ViT 训练详解 + +作者:Yuxuan Lou + +**示例代码** + +- [Colossal-AI Examples ViT on Cifar10](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/vision_transformer) + +**相关文献** +- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf) + + +## 引言 + +在这个ViT模型的样例中,Colossal-AI 提供了三种不同的并行技术来加速模型训练:数据并行,流水线并行和张量并行。我们将展示如何使用这三种并行技术在 CIFAR-10 数据集上训练 ViT。为了运行项目,需要2-4个 GPU。 + + +## 目录 +1. Colossal-AI 安装方法 +2. 使用数据并行训练 ViT 步骤 +3. 使用数据流水线并行训练 ViT 步骤 +4. 使用张量并行或异构并行训练 ViT 步骤 + +## Colossal-AI 安装 +可以通过 Python 的官方索引来安装 Colossal-AI 软件包。 +```bash +pip install colossalai +``` + + + +## 数据并行 +数据并行是实现加速模型训练的基本方法。通过两步可以实现训练的数据并行: +1. 构建一个配置文件 +2. 在训练脚本中修改很少的几行代码 + +### 构建配置文件 (`data_parallel/config.py`) +为了使用 Colossal-AI,第一步是构建配置文件。并且,在这里有两种变量: + +1. **Colossal-AI 功能配置** + +Colossal-AI 提供了一系列的功能来加快训练速度(包括模型并行,混合精度,零冗余优化器等)。每个功能都是由配置文件中的相应字段定义的。如果我们只用到数据并行,那么我们只需要具体说明并行模式。在本例中,我们使用 PyTorch 最初提出的混合精度训练,只需要定义混合精度配置 `fp16 = dict(mode=AMP_TYPE.TORCH)` 。 + +2. **全局超参数** + +全局超参数包括特定于模型的超参数、训练设置、数据集信息等。 + +```python +from colossalai.amp import AMP_TYPE +# ViT Base +BATCH_SIZE = 256 +DROP_RATE = 0.1 +NUM_EPOCHS = 300 +# mix precision +fp16 = dict( + mode=AMP_TYPE.TORCH, +) +gradient_accumulation = 16 +clip_grad_norm = 1.0 +dali = dict( + gpu_aug=True, + mixup_alpha=0.2 +) +``` + +### 修改训练脚本 (`/data_parallel/train_with_cifar10.py`) + +#### 导入模块 +- Colossal-AI 相关模块 +```python +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.lr_scheduler import LinearWarmupLR +from colossalai.nn.metric import Accuracy +from colossalai.trainer import Trainer, hooks +``` + +- 其他模块 +```python +import os +import torch +from timm.models import vit_base_patch16_224 +from torchvision import transforms +from torchvision.datasets import CIFAR10 +``` + +#### 启动 Colossal-AI + +在训练脚本中,在构建好配置文件后,需要为 Colossal-AI 初始化分布式环境。我们将此过程称为 `launch` 。在 Colossal-AI 中,我们提供了几种启动方法来初始化分布式后端。在大多数情况下,您可以使用 `colossalai.launch` 和 `colossalai.get_default_parser ` 来实现使用命令行传递参数。此外,Colossal-AI 可以利用 PyTorch 提供的现有启动工具,正如许多用户通过使用熟知的 `colossalai.launch_from_torch` 那样。更多详细信息,您可以查看相关[文档](https://www.colossalai.org/docs/basics/launch_colossalai)。 + + +```python +# initialize distributed setting +parser = colossalai.get_default_parser() +args = parser.parse_args() +colossalai.launch_from_torch(config=args.config) +disable_existing_loggers() +logger = get_dist_logger() +``` + +初始化后,您可以使用 `colossalai.core.global_context` 访问配置文件中的变量。 + +```python +#access parameters +print(gpc.config.BATCH_SIZE) +``` + +#### 构建模型 + +如果只需要数据并行性,则无需对模型代码进行任何更改。这里,我们使用 `timm` 中的 `vit_base_patch16_224`。 + +```python +# build model +model = vit_base_patch16_224(drop_rate=0.1, num_classes=gpc.config.NUM_CLASSES) +``` + +#### 构建 CIFAR-10 数据加载器 +`colossalai.utils.get_dataloader` 可以帮助您轻松构建数据加载器。 + +```python +def build_cifar(batch_size): + transform_train = transforms.Compose([ + transforms.RandomCrop(224, pad_if_needed=True), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + transform_test = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train) + test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test) + train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True) + test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True) + return train_dataloader, test_dataloader +# build dataloader +train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE) +``` + +#### 定义优化器,损失函数和学习率调度器 + +Colossal-AI 提供了自己的优化器、损失函数和学习率调度器。PyTorch 的这些组件与Colossal-AI也兼容。 + +```python +# build optimizer +optimizer = colossalai.nn.Lamb(model.parameters(), lr=1.8e-2, weight_decay=0.1) +# build loss +criterion = torch.nn.CrossEntropyLoss() +# lr_scheduelr +lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS) +``` + +#### 启动用于训练的 Colossal-AI 引擎 + +Engine 本质上是对模型、优化器和损失函数的封装类。当我们使用 `colossalai.initialize` ,将返回一个 engine 对象,并且它已经按照配置文件中的指定内容,配置了梯度剪裁、梯度累积和零冗余优化器等功能。之后,基于 Colossal-AI 的 engine 我们可以进行模型训练。 + +```python +engine, train_dataloader, test_dataloader, _ = colossalai.initialize( + model, optimizer, criterion, train_dataloader, test_dataloader + ) +``` + +#### 训练:Trainer 应用程序编程接口 +Trainer 是一个更高级的封装类,用户可以用更少的代码就可以实现训练。通过传递 engine 对象很容易创建 trainer 对象。 + +此外,在 trainer 中,用户可以自定义一些挂钩,并将这些挂钩连接到 trainer 对象。钩子对象将根据训练方案定期执行生命周期方法。例如,`LRSchedulerHook` 将执行`lr_scheduler.step()` 在 `after_train_iter` 或 `after_train_epoch` 阶段更新模型的学习速率。 + +```python +# build trainer +trainer = Trainer(engine=engine, logger=logger) +# build hooks +hook_list = [ + hooks.LossHook(), + hooks.AccuracyHook(accuracy_func=MixupAccuracy()), + hooks.LogMetricByEpochHook(logger), + hooks.LRSchedulerHook(lr_scheduler, by_epoch=True), + # comment if you do not need to use the hooks below + hooks.SaveCheckpointHook(interval=1, checkpoint_dir='./ckpt'), + hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]), +] +``` + +使用 `trainer.fit` 进行训练: + +```python +# start training +trainer.fit( + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + epochs=gpc.config.NUM_EPOCHS, + hooks=hook_list, + display_progress=True, + test_interval=1 +) +``` + +### 开始训练 +`DATA` 是自动下载和存储 CIFAR-10 数据集的文件路径。 + +`` 是要用于使用 CIFAR-10 数据集,以数据并行方式训练 ViT 的 GPU 数。 + +```bash +export DATA= +# If your torch >= 1.10.0 +torchrun --standalone --nproc_per_node train_dp.py --config ./configs/config_data_parallel.py +# If your torch >= 1.9.0 +# python -m torch.distributed.run --standalone --nproc_per_node= train_dp.py --config ./configs/config_data_parallel.py +# Otherwise +# python -m torch.distributed.launch --nproc_per_node --master_addr --master_port 29500 train_dp.py --config ./configs/config.py +``` + + + +## 流水线并行 +除了数据并行性,Colossal-AI 还支持流水线并行。具体而言,Colossal-AI 使用 NVIDIA 引入的 1F1B 流水线。更多详细信息,您可以查看相关[文档](https://www.colossalai.org/tutorials/features/pipeline_parallel)。 + +### 构建配置文件(`hybrid_parallel/configs/vit_pipeline.py`) +要在数据并行的基础上应用流水线并行,只需添加一个 **parallel dict** +```python +from colossalai.amp import AMP_TYPE +parallel = dict( + pipeline=2 +) +# pipeline config +NUM_MICRO_BATCHES = parallel['pipeline'] +TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE) +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 +``` + +其他配置: +```python +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 256 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 300 +WARMUP_EPOCHS = 32 +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 768 +DEPTH = 12 +NUM_HEADS = 12 +MLP_RATIO = 4 +NUM_CLASSES = 10 +CHECKPOINT = True +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token +``` + +### 构建流水线模型 (`/hybrid_parallel/model/vit.py`) +Colossal-AI 提供了两种从现有模型构建流水线模型的方法。 +- `colossalai.builder.build_pipeline_model_from_cfg` +- `colossalai.builder.build_pipeline_model` + +此外,您还可以使用 Colossal-AI 从头开始构建流水线模型。 +```python +import math +from typing import Callable +import inspect +import torch +from colossalai import nn as col_nn +from colossalai.registry import LAYERS, MODELS +from colossalai.logging import get_dist_logger +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.builder.pipeline import partition_uniform +from torch import dtype, nn +from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead +@MODELS.register_module +class PipelineVisionTransformer(nn.Module): + def __init__(self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + num_classes: int = 1000, + depth: int = 12, + num_heads: int = 12, + dim: int = 768, + mlp_ratio: int = 4, + attention_dropout: float = 0., + dropout: float = 0.1, + drop_path: float = 0., + layernorm_epsilon: float = 1e-6, + activation: Callable = nn.functional.gelu, + representation_size: int = None, + dtype: dtype = None, + bias: bool = True, + checkpoint: bool = False, + init_method: str = 'torch', + first_stage=True, + last_stage=True, + start_idx=None, + end_idx=None,): + super().__init__() + layers = [] + if first_stage: + embed = ViTEmbedding(img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embedding_dim=dim, + dropout=dropout, + dtype=dtype, + init_method=init_method) + layers.append(embed) + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] + if start_idx is None and end_idx is None: + start_idx = 0 + end_idx = depth + blocks = [ + ViTBlock( + dim=dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + attention_dropout=attention_dropout, + dropout=dropout, + drop_path=dpr[i], + activation=activation, + dtype=dtype, + bias=bias, + checkpoint=checkpoint, + init_method=init_method, + ) for i in range(start_idx, end_idx) + ] + layers.extend(blocks) + if last_stage: + norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) + head = ViTHead(dim=dim, + num_classes=num_classes, + representation_size=representation_size, + dtype=dtype, + bias=bias, + init_method=init_method) + layers.extend([norm, head]) + self.layers = nn.Sequential( + *layers + ) + def forward(self, x): + x = self.layers(x) + return x +def _filter_kwargs(func, kwargs): + sig = inspect.signature(func) + return {k: v for k, v in kwargs.items() if k in sig.parameters} +def _build_pipeline_vit(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs): + logger = get_dist_logger() + if gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + else: + pipeline_size = 1 + pipeline_rank = 0 + rank = gpc.get_global_rank() + parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] + models = [] + for start, end in parts: + kwargs['first_stage'] = start == 0 + kwargs['last_stage'] = end == num_layers + kwargs['start_idx'] = start + kwargs['end_idx'] = end + logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') + chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device) + models.append(chunk) + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + return model +def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): + return _build_pipeline_vit(PipelineVisionTransformer, num_layers, num_chunks, device, **kwargs) +``` + +### 修改训练脚本 (`/hybrid_parallel/train_with_cifar10.py`) + +#### 导入模块 +```python +from colossalai.engine.schedule import (InterleavedPipelineSchedule, + PipelineSchedule) +from colossalai.utils import MultiTimer +import os +import colossalai +import torch +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.utils import is_using_pp, get_dataloader +from model.vit import build_pipeline_vit +from model_zoo.vit.vit import _create_vit_model +from tqdm import tqdm +from torchvision import transforms +from torchvision.datasets import CIFAR10 +``` + +#### 启动 Colossal-AI +`colossalai.utils.is_using_pp` 可以帮您检查配置文件是否满足流水线并行的要求。 + +```python +# initialize distributed setting +parser = colossalai.get_default_parser() +args = parser.parse_args() +# launch from torch +colossalai.launch_from_torch(config=args.config) +# get logger +logger = get_dist_logger() +logger.info("initialized distributed environment", ranks=[0]) +if hasattr(gpc.config, 'LOG_PATH'): + if gpc.get_global_rank() == 0: + log_path = gpc.config.LOG_PATH + if not os.path.exists(log_path): + os.mkdir(log_path) + logger.log_to_file(log_path) +use_pipeline = is_using_pp() +``` + +#### 定义模型 + +```python +# create model +model_kwargs = dict(img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + dim=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=gpc.config.NUM_CLASSES, + init_method='jax', + checkpoint=gpc.config.CHECKPOINT) +if use_pipeline: + model = build_pipeline_vit(num_layers=model_kwargs['depth'], num_chunks=1, **model_kwargs) +else: + model = _create_vit_model(**model_kwargs) +``` + +#### 计算参数个数 + +您可以轻松计算不同流水线阶段上的模型参数个数。 + +``` +# count number of parameters +total_numel = 0 +for p in model.parameters(): + total_numel += p.numel() +if not gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_stage = 0 +else: + pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) +logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") +``` + +#### 构建数据加载器,优化器等组件 + +```python +def build_cifar(batch_size): + transform_train = transforms.Compose([ + transforms.RandomCrop(224, pad_if_needed=True), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + transform_test = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train) + test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test) + train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True) + test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True) + return train_dataloader, test_dataloader + + +# craete dataloaders +train_dataloader , test_dataloader = build_cifar() +# create loss function +criterion = CrossEntropyLoss(label_smoothing=0.1) +# create optimizer +optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) +# create lr scheduler +lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS, + warmup_steps=gpc.config.WARMUP_EPOCHS) +``` + +#### 启动 Colossal-AI 引擎 + +```python +# intiailize +engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader) +logger.info("Engine is built", ranks=[0]) +``` + +#### 训练:基于engine + +在数据并行示例中,我们展示了如何使用 Trainer API 训练模型。我们还可以直接训练基于 engine 的模型。通过这种方式,您可以使用更多功能自定义训练方法。 + +```python +data_iter = iter(train_dataloader) +for epoch in range(gpc.config.NUM_EPOCHS): + # training + engine.train() + if gpc.get_global_rank() == 0: + description = 'Epoch {} / {}'.format( + epoch, + gpc.config.NUM_EPOCHS + ) + progress = tqdm(range(len(train_dataloader)), desc=description) + else: + progress = range(len(train_dataloader)) + for _ in progress: + engine.zero_grad() + engine.execute_schedule(data_iter, return_output_label=False) + engine.step() + lr_scheduler.step() +``` + +### 开始训练 +```bash +export DATA= +# If your torch >= 1.10.0 +torchrun --standalone --nproc_per_node train_hybrid.py --config ./configs/config_pipeline_parallel.py +# If your torch >= 1.9.0 +# python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_pipeline_parallel.py +``` + + + + +## 张量并行和异构并行 +张量并行将每个权重参数跨多个设备进行分区,以减少内存负载。Colossal-AI 支持 1D、2D、2.5D 和 3D 张量并行。此外,还可以将张量并行、流水线并行和数据并行结合起来,实现混合并行。Colossal-AI 还提供了一种简单的方法来应用张量并行和混合并行。只需在配置文件中更改几行代码即可实现流水线并行。 + +### 构造您的配置文件 (`/hybrid_parallel/configs/vit_1d_tp2_pp2.py`) +使用张量并行,只需将相关信息添加到 **parallel dict**。具体而言,`TENSOR_PARALLEL_MODE` 可以是“1d”、“2d”、“2.5d”、“3d”。不同并行度的大小应满足:`#GPUs = pipeline parallel size x tensor parallel size x data parallel size`。在指定 GPU 数量、流水线并行大小和张量并行大小后 `data parallel size` 会自动计算。 + +```python +from colossalai.amp import AMP_TYPE +# parallel setting +TENSOR_PARALLEL_SIZE = 2 +TENSOR_PARALLEL_MODE = '1d' +parallel = dict( + pipeline=2, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE) +) +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 +# pipeline config +NUM_MICRO_BATCHES = parallel['pipeline'] +TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE) +``` + +其他配置: +```python +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 256 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 300 +WARMUP_EPOCHS = 32 +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 768 +DEPTH = 12 +NUM_HEADS = 12 +MLP_RATIO = 4 +NUM_CLASSES = 10 +CHECKPOINT = True +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token +``` + +### 开始训练 +```bash +export DATA= +# If your torch >= 1.10.0 +torchrun --standalone --nproc_per_node train_hybrid.py --config ./configs/config_hybrid_parallel.py +# If your torch >= 1.9.0 +# python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_hybrid_parallel.py +``` diff --git a/docs/source/zh-Hans/basics/colotensor_concept.md b/docs/source/zh-Hans/basics/colotensor_concept.md new file mode 100644 index 0000000000000000000000000000000000000000..d6a332df2e9c4649e8184769e1b5650cf4c3823a --- /dev/null +++ b/docs/source/zh-Hans/basics/colotensor_concept.md @@ -0,0 +1,97 @@ +# ColoTensor Concepts + +Author: [Jiarui Fang](https://github.com/feifeibear), [Hongxin Liu](https://github.com/ver217) and [Haichen Huang](https://github.com/1SAA) + +**Prerequisite:** +- [Colossal-AI Overview](../concepts/colossalai_overview.md) +- [Distributed Training](../concepts/distributed_training.md) +- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md) + +## Introduction + +在ColossalAI 0.1.8 版本之后,[ColoTensor](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ColoTensor) 成为 ColossalAI 中张量的基本数据结构。 它是 torch.Tensor 的子类,可以当做 PyTorch Tensor使用。 此外,一些独特的功能使其能够表示一个payload分布在多个 GPU 设备上的Global Tensor,并提供一些列方式操作这个Global Tensor。 在 ColoTensor 的帮助下,用户可以以类似编写串行程序方式,编写的分布式 DNN 训练程序。 + +ColoTensor 包含额外的属性[ColoTensorSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.tensor_spec.html#colossalai.tensor.tensor_spec.ColoTensorSpec) +来描述张量的payload分布和计算模式。 + +- ProcessGroup:如何将进程组织为通信组。 +- Distributed Spec:张量如何在进程组之间分布。 +- Compute Spec:计算过程中如何使用张量。 + +我们一一详述。 + +## ProcessGroup + +[ProcessGroup](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ProcessGroup) 类的一个实例描述了如何在进程组中组织进程。进程组内的进程可以一起参与同一个集合通信,比如allgather, allreduce等。进程组组织方式被张量的并行策略支配。比如,如果用户定义了Tensor的张量并行(TP),数据并行(DP)方式,那么进程组的进程组织方式将被自动推导出来。 进程组设置可能因不同的张量而异。 因此,它使我们能够支持更复杂的混合并行。流水线并行(PP)定义不在ProcessGroup中描述,它需要另一套机制,我们将在未来补充ColoTensor应用于PP的相关内容。 + +目前,ColoTensor 的一个进程组由 tp_degree 和 dp_degree 两种配置定义。 在 DP+TP 混合并行的情况下,可以将设备视为 2D 网格。 我们将 TP 通信组放置在设备网格的前导低维上,然后将数据并行组放置在设备网格的高维上。 原因是张量并行比数据并行具有更大的通信开销。 相邻设备放置在一个 TP 进程组内,并且通常放置在同一个节点中。 + +考虑到8个进程配置为tp_degree=4,dp_degree=2,布局如下图。 进程组 tp0 包含 gpu 0,1,2,3。 进程 dp1 包含 gpu 1 和 5。 + +
+ +
Process Group using tp_degree=4, dp_degree=2
+
+ +## Distributed Spec + +[Distributed Spec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html)描述了 ColoTensor 如何在 ProcessGroup 中分布。 + +张量在 DP 进程组之间的分布方式是自动导出的,不需要用户手动指定。 如果这个张量是一个模型参数,它会在 DP 进程组中被复制。 如果是activation张量,则沿tensor最高维度在DP进程组中进行平均分割。 + +因此,在使用 Distributed Spec 时,我们只需要描述张量在 TP 进程组之间的分布方式即可。 TP 进程组目前有两种分布式规范,即 [ShardSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ShardSpec)和[ReplicaSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ReplicaSpec)。 ShardSpec 需要指定分区的维度索引 dim 和分区个数 num_partitions。 目前,我们仅支持在单个dim上进行拆分。 TP进程组上不同的dist spec可以通过set_dist_spec()接口相互转换。这些转化操作可以被记录在PyTorch的自动求导机制中,并在反向传播时候触发对应的反向操作。 + +## Compute Spec + +[ComputeSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.compute_spec.html#colossalai.tensor.compute_spec.ComputeSpec)类描述Tensor如何参与计算。目前,我们将作为module parameter的ColoTensor设置正确的Compute Pattern。可以触发正取的计算模式。具体应用方式我们会在接下来的文档中展示。 + +## ColoParameter + +[ColoParameter](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.colo_parameter.html#colossalai.tensor.colo_parameter.ColoParameter)是ColoTensor的子类。用来声明Parameter。他和ColoTensor关系和Torch.Tensor和torch.Parameter一致。后者可以让tensor出现在module的parameters()和name_parameters() 的返回值中。 + +## Example + +让我们看一个例子。 使用 tp_degree=4, dp_dgree=2 在 8 个 GPU 上初始化并Shard一个ColoTensor。 然后tensor被沿着 TP 进程组中的最后一个维度进行分片。 最后,我们沿着 TP 进程组中的第一个维度(dim 0)对其进行重新Shard。 我们鼓励用户运行代码并观察每个张量的形状。 + + +```python +import torch +import torch.multiprocessing as mp +from colossalai.utils import print_rank_0 +from functools import partial + +import colossalai +from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern +from colossalai.testing import spawn + +import torch + +def run_dist_tests(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(tp_degree=2, dp_degree=2) + + torch.manual_seed(0) + local_tensor = torch.randn(2, 3, 1).cuda() + print_rank_0(f"shape {local_tensor.shape}, {local_tensor.data}") + + spec = ColoTensorSpec(pg, ShardSpec(dims=[-1], num_partitions=[pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + t1 = ColoTensor.from_torch_tensor(local_tensor, spec) + t1 = t1.to_replicate() + print_rank_0(f"shape {t1.shape}, {t1.data}") + + spec2 = ShardSpec([0], [pg.tp_world_size()]) + t1.set_dist_spec(spec2) + print_rank_0(f"shape {t1.shape}, {t1.data}") + +def test_dist_cases(world_size): + spawn(run_dist_tests, world_size) + +if __name__ == '__main__': + test_dist_cases(4) +``` + +:::caution + +The ColoTensor is an experimental feature and may be updated. + +::: diff --git a/docs/source/zh-Hans/basics/command_line_tool.md b/docs/source/zh-Hans/basics/command_line_tool.md new file mode 100644 index 0000000000000000000000000000000000000000..9b0275a6ceddb3ad6da16b068bb8dcf71eea8907 --- /dev/null +++ b/docs/source/zh-Hans/basics/command_line_tool.md @@ -0,0 +1,47 @@ +# 命令行工具 + +作者: Shenggui Li + +**预备知识:** +- [Distributed Training](../concepts/distributed_training.md) +- [Colossal-AI Overview](../concepts/colossalai_overview.md) + +## 简介 + +Colossal-AI给用户提供了命令行工具,目前命令行工具可以用来支持以下功能。 +- 检查Colossal-AI是否安装正确 +- 启动分布式训练 +- 张量并行基准测试 + +## 安装检查 + +用户可以使用`colossalai check -i`这个命令来检查目前环境里的版本兼容性以及CUDA Extension的状态。 + +
+ +
Check Installation Demo
+
+ +## 启动分布式训练 + +在分布式训练时,我们可以使用`colossalai run`来启动单节点或者多节点的多进程,详细的内容可以参考[启动 Colossal-AI](./launch_colossalai.md)。 + +## 张量并行基准测试 + +Colossal-AI提供了多种张量并行,想要充分理解这些方法需要一定的学习成本,对于新手来说很难靠经验选择一个并行方式。 +所以我们提供了一个简单的基准测试,能够让用户在自己的机器上测试不同张量并行的性能。这个基准测试跑一个并行的MLP模型, +输入数据的维度为`(批大小,序列长度,隐藏层维度)`。通过指定GPU的数量,Colossal-AI会搜索所有可行的并行配置。用户可以通过查看`colossalai benchmark --help`来自定义相关的测试参数。 + +```shell +# 使用4个GPU +colossalai benchmark --gpus 4 + +# 使用8个GPU +colossalai benchmark --gpus 8 +``` + +:::caution + +目前仅支持单节点的基准测试。 + +::: diff --git a/docs/source/zh-Hans/basics/configure_parallelization.md b/docs/source/zh-Hans/basics/configure_parallelization.md new file mode 100644 index 0000000000000000000000000000000000000000..eb4b38f48ddb3653735f3dfd49fa263fcdbc9d07 --- /dev/null +++ b/docs/source/zh-Hans/basics/configure_parallelization.md @@ -0,0 +1,136 @@ +# 并行配置 + +作者: Shenggui Li, Siqi Mai + +**预备知识:** +- [分布式训练](../concepts/distributed_training.md) +- [并行技术](../concepts/paradigms_of_parallelism.md) +- [构建配置文件](./define_your_config.md) + + +## 简介 + +我们在 Colossal-AI 中支持多种并行技术。代码库中的混合并行是指您可以轻松地结合数据并行、流水线并行和张量并行(1D、2D、2.5D、3D)的优势共同来进行并行训练。 + +每种并行方式需要不同的网络拓扑结构,因此要初始化不同的进程组。您可以通过在配置文件中设置 `parallel` 来初始化相应的进程组。 `parallel` 的配置必须遵从以下格式。数据并行度的大小将被根据您对流水线并行和张量并行的输入自动推断。`colossalai.launch` 将根据您的配置自动初始化这些分布式进程组。 + +我们为您提供了一些配置的例子以供参考。 + +```python +# sampler format +parallel = dict( + pipeline=dict("size": int), + tensor=dict("size": int, "mode": '1d' or '2d' or '2.5d' or '3d', "kwargs": Any) +) + +# this is ok +parallel = dict( + pipeline=dict(size=2), + tensor=dict(size=4, mode='2d') +) + +# this is ok +parallel = dict( + pipeline=2, + tensor=dict(size=4, mode='2d') +) + +# this is not ok +# as you need to specify the mode for tensor parallelism +parallel = dict( + pipeline=2, + tensor=4 +) + +# this is ok as well as tensor will be default to size 1 +# and mode None +parallel = dict( + pipeline=2 +) + +# this is ok as well as pipeline will default to size 1 +parallel = dict( + tensor=dict(size=4, mode='2d') +) + +``` + +关键字 `size` 指的是并行维度的并行大小。 例如,流水线大小为2意味着有 +将有2个流水线阶段。张量并行配置中的关键字 `mode` 意味着相应的张量并行技术 +将被初始化,如1D、2D、2.5D、3D。 + +**您也可以选择不在您的配置中使用 "并行",此时流水线和张量的并行度都将默认为大小1。** + +**GPU的总数量必须等于` 数据并行大小 x 张量并行大小 x 流水线并行大小` 。** + +## 数据并行 + +数据并行是最常见的分布式训练方式。它将数据分割成几个碎片分别在每个设备上进行训练。数据并行的配置会自动检测并为您设置。您不需要在您的配置中明确地设置它们。在Colossal-AI 中,有两种方法来处理数据并行的 all-reduce。 + +1. 如果您设置了梯度handler,梯度handler将会all-reduce梯度。 +2. 若没有指定相应的配置,Colossal-AI 将会使用 PyTorch 的 DistributedDataParallel。 + +在大多数情况下,若您对梯度没有复杂的处理的需求,您将会使用第二种模式。 + +## 1D, 2D, 2.5D 和 3D 并行 + +为了实现混合并行,我们提供了一系列张量并行方法。您可以阅读相应的学术论文进行深入的了解。这些并行模式需要和 Colossal-AI 提供的分布式层一同工作。 + +- 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) + +- 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343) + 2D 并行基于 SUMMA 矩阵乘法,它将输入数据、模型权重和层输出切分成两个不同的维度。 这些张量块分布在 `P = N^2` 设备的二维网格上,其中 `N` 是单一维度上张量块的数量。 + +- 2.5D: [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500) + 在 2.5D 矩阵乘法的启发下,2.5D 并行引入了一种新的张量并行,进一步将2D张量并行化。其中,`P = N^2 ∗ d` 个处理器被分配到 `d` 层, 每层独立进行矩阵乘法运算,维度为 `N`。 + +- 3D: [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450) + 我们还介绍了一种 3D 张量并行方法,在三维处理器立方体上并行化神经网络。这种方法在数量为 `P` 的处理器上实现了最佳的 `O(P^{1/3})` 通信开销,而计算和内存的使用都是通过优化的参数和激活的负载平衡来实现的。同时,通过优化参数和 activations 的负载平衡,计算和内存的使用都是均匀分布的。 + +```python +# 1D parallel +parallel = dict( + tensor=dict(size=4, mode='1d') +) + +# 2D parallel +parallel = dict( + tensor=dict(size=4, mode='2d') +) + +# 2.5D parallel +parallel = dict( + tensor=dict(size=8, mode='2.5d', depth=2) +) + +# 3D parallel +parallel = dict( + tensor=dict(size=8, mode='3d') +) +``` + +当您在配置中指定了张量并行模式,您就可以使用其相应的分布式算子。例如,若您设置模式为 `2d`,那么在模型构建中就能使用 `colossalai.nn.Linear2D` 了。 + + +## 流水线并行 + +流水线并行是将模型按层分成几个部分。例如,假设我们有一个简单的模型,它由两个线性层组成。我们有两个 GPU,我们可以将第一个线性层分配给第一个 GPU 而第二层则分配给第二个 GPU。 + +您可以在您的配置文件中设置流水线并行度的大小。当流水线并行度大于1,Colossal-AI 将会自动地创建流水线并行的 schedule,这将会为您定义好模型训练的 `forward` 和 `backward`。 + +```python +parallel = dict( + pipeline=dict(size=4), # number of pipeline stages +) +``` + +## 序列并行 + +针对处理大图片、视频、长文本、长时间医疗监控等数据的需要,Colossal-AI 还提供了序列并行的方法。该方法是在论文[Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120)中提出的。您可以指定模式为 `sequence` 来初始化进程组。 + + +```python +parallel = dict( + tensor=dict(size=4, mode='sequence') +) +``` diff --git a/docs/source/zh-Hans/basics/define_your_config.md b/docs/source/zh-Hans/basics/define_your_config.md new file mode 100644 index 0000000000000000000000000000000000000000..d7e49cbf23dee33fcf497feb3754e507d05c2442 --- /dev/null +++ b/docs/source/zh-Hans/basics/define_your_config.md @@ -0,0 +1,71 @@ +# 构建配置文件 + +作者: Guangyang Lu, Shenggui Li, Siqi Mai + +**预备知识:** +- [分布式训练](../concepts/distributed_training.md) +- [Colossal-AI 总览](../concepts/colossalai_overview.md) + + +## 简介 + +在 Colossal-AI 中,我们需要一个配置文件来指定系统在训练过程中要注入的特征。在本教程中,我们将向您介绍如何构建您的配置文件以及如何使用这个配置文件。使用配置文件有以下一些好处: + +1. 您可以在不同的配置文件中存储您的特征配置和训练超参数。 +2. 对于我们未来发布的新功能,您亦可以在配置中指定,而无需改变训练脚本的代码。 + +在本教程中,我们将向您介绍如何构建您的配置文件。 + +## 配置定义 + +在一个配置文件中,有两种类型的变量。一种是作为特征说明,另一种是作为超参数。所有与特征相关的变量都是保留关键字。例如,如果您想使用混合精度训练,需要在 config 文件中使用变量名`fp16`,并遵循预先定义的格式。 + +### 功能配置 + +Colossal-AI 提供了一系列的功能来加快训练速度。每个功能都是由配置文件中的相应字段定义的。在本教程中,我们不会给出所有功能的配置细节,而是提供一个如何指定一个功能的说明。**每个功能的细节可以在其各自的教程中找到。** + +为了说明配置文件的使用,我们在这里使用混合精度训练作为例子。您需要遵循以下步骤。 + +1. 创建一个配置文件(例如 `config.py`,您可以指定任意的文件名)。 +2. 在配置文件中定义混合精度的配置。例如,为了使用 PyTorch 提供的原始混合精度训练,您只需将下面这几行代码写入您的配置文件中。 + + ```python + from colossalai.amp import AMP_TYPE + + fp16 = dict( + mode=AMP_TYPE.TORCH + ) + ``` + +3. 当启动分布式环境时,向 Colossal-AI 指定您的配置文件的位置。比如下面的例子是配置文件在当前目录下。 + + ```python + import colossalai + + colossalai.launch(config='./config.py', ...) + ``` + +这样,Colossal-AI 便知道您想使用什么功能,并会在 `colossalai.initialize` 期间注入您所需要的功能。 + +### 全局超参数 + +除了功能的配置,您还可以在配置文件中定义训练的超参数。当您想进行多个实验时,这将会变得非常方便。每个实验的细节都可以放在独立的配置文件中,以避免混乱。这些参数将被存储在全局并行环境中,可以在训练脚本中访问。 + +例如,您可以在配置文件中指定批量大小。 + +```python +BATCH_SIZE = 32 +``` + +启动后,您能够通过全局并行上下文访问您的超参数。 + +```python +import colossalai +from colossalai.core import global_context as gpc + +colossalai.launch(config='./config.py', ...) + +# access your parameter +print(gpc.config.BATCH_SIZE) + +``` diff --git a/docs/source/zh-Hans/basics/engine_trainer.md b/docs/source/zh-Hans/basics/engine_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..a7519bfca14f7dbddff22ef9dec20ae3cdc17e93 --- /dev/null +++ b/docs/source/zh-Hans/basics/engine_trainer.md @@ -0,0 +1,384 @@ +# 如何在训练中使用 Engine 和 Trainer + +作者: Shenggui Li, Siqi Mai + +**预备知识:** +- [初始化功能](./initialize_features.md) + +## 简介 + +在本教程中,您将学习如何使用 Colossal-AI 中提供的 Engine 和 Trainer 来训练您的模型。在深入研究细节之前,我们想先解释一下 Engine 和 Trainer 的概念。 + +### Engine + +Engine 本质上是一个模型、优化器和损失函数的封装类。当我们调用 `colossalai.initialize` 时,一个 Engine 对象将被返回,并且配备了在您的配置文件中指定的梯度剪裁、梯度累计和 ZeRO 优化器等功能。 + +Engine 将使用与 PyTorch 训练组件类似的 API,因此您只需对代码进行微小的修改即可。 + +下表展示了Engine的常用API。 + +| 组件 | 功能 | PyTorch | Colossal-AI | +| ------------------------------------- | --------------------------------------------- | ------------------------------- | -------------------------------------- | +| optimizer | 迭代前将所有梯度设置为零 | optimizer.zero_grad() | engine.zero_grad() | +| optimizer | 更新参数 | optimizer.step() | engine.step() | +| model | 进行一次前向计算 | outputs = model(inputs) | outputs = engine(inputs) | +| criterion | 计算loss值 | loss = criterion(output, label) | loss = engine.criterion(output, label) | +| criterion | 反向计算 | loss.backward() | engine.backward(loss) | + +我们需要这样一个 Engine 类的原因是,我们可以添加更多的功能,同时将实现隐藏在 +`colossalai.initialize` 函数中实现。 +假如我们要添加一个新的功能,我们可以在 `colossalai.initialize` 函数中完成对于模型、优化器、数据加载器和损失函数的功能诠释。不管中间的过程有多复杂,最终我们呈现的以及用户需要使用的只有一个 Engine 类,这将十分便捷。 +用户只需要在最小范围内修改他们的代码,将普通的 PyTorch APIs 调整为 Colossal-AI +Engine 的 API。通过这种方式,他们可以享受更多的功能来进行有效的训练。 + +以下是一个简单的例子: + +```python +import colossalai + +# build your model, optimizer, criterion, dataloaders +... + +engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, + optimizer, + criterion, + train_dataloader, + test_dataloader) +for img, label in train_dataloader: + engine.zero_grad() + output = engine(img) + loss = engine.criterion(output, label) + engine.backward(loss) + engine.step() +``` + +### Trainer + +Trainer 是一个更高级的封装器,用户可以用更少的代码行来执行训练。 由于 Trainer 的使用会更加简单,相较于 Engine,它会缺少一点灵活性。 Trainer 被设计为进行前向和反向计算来进行模型权重的更新。通过传递 Engine 对象,我们可以很容易地创建一个 Trainer。 +Trainer 的参数 `schedule` 默认值是 `None` 。在大多数情况下,除非我们想使用流水线并行,否则我们把这个值设为 `None`。如果您想探索更多关于这个参数的内容,您可以前往流水线并行的相关教程。 + +```python +from colossalai.logging import get_dist_logger +from colossalai.trainer import Trainer, hooks + +# build components and initialize with colossalai.initialize +... + +# create a logger so that trainer can log on the console +logger = get_dist_logger() + +# create a trainer object +trainer = Trainer( + engine=engine, + logger=logger +) +``` + +在 Trainer 中,用户可以定制一些 hooks,并将这些 hooks 附加到 Trainer 上。hook 将根据训练方案定期地执行生命周期函数。例如,基于用户是想在每次训练迭代后还是只在整个训练周期后更新学习率, +`LRSchedulerHook` 将会在 `after_train_iter` 或 `after_train_epoch` 阶段执行 `lr_scheduler.step()` 去为用户更新学习率。您可以将 hook 存储在一个列表中并将其传递给 `trainer.fit` 方法。`trainer.fit` 方法将根据您的参数执行训练和测试。如果 `display_process` 为 True,将在您的控制台显示一个进度条,以显示训练的过程。 + + +```python +# define the hooks to attach to the trainer +hook_list = [ + hooks.LossHook(), + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True), + hooks.AccuracyHook(accuracy_func=Accuracy()), + hooks.LogMetricByEpochHook(logger), +] + +# start training +trainer.fit( + train_dataloader=train_dataloader, + epochs=NUM_EPOCHS, + test_dataloader=test_dataloader, + test_interval=1, + hooks=hook_list, + display_progress=True +) +``` + +如果您想定制您的 hook 类,您可以继承 `hooks.BaseHook` 并重写您想要的生命周期方法。下面提供了一个例子来演示如何创建一个简单的关于日志信息的 hook,以供您参考。 + +```python +from colossalai.logging import get_dist_logger +from colossalai.trainer import hooks + +class LogMessageHook(hooks.BaseHook): + + def __init__(self, priority=10): + self._logger = get_dist_logger() + + def before_train(self, trainer): + self._logger.info('training starts') + + def after_train(self, trainer): + self._logger.info('training finished') + + +... + +# then in your training script +hook_list.append(LogMessageHook()) +``` + + + +在下面的章节中,您将会详细地了解到如何用 Engine 和 Trainer 来训练 ResNet 模型。 + + +## ResNet + +### 总览 + +在本节中,我们将介绍: + +1. 使用一个 Engine 在 CIFAR10 数据集上训练 ResNet34 模型 +2. 使用一个 Trainer 在 CIFAR10 数据集上训练 ResNet34 模型 + +项目结构如下: + +```bash +-- config.py +-- run_resnet_cifar10_with_engine.py +-- run_resnet_cifar10_with_trainer.py +``` + +对于使用 Engine 或 Trainer,步骤 1-4 是通用的。 因此,步骤 1-4 + 步骤 5 将会是对应 `run_resnet_cifar10_with_engine.py` 而 步骤 1-4 + 步骤6 则对应 `run_resnet_cifar10_with_trainer.py`。 + +### 牛刀小试 + +#### 步骤 1. 创建配置文件 + +在你的项目文件夹中,创建一个 `config.py`。这个文件是用来指定一些您可能想用来训练您的模型的特征。下面是一个配置文件的例子。 + +```python +from colossalai.amp import AMP_TYPE + +BATCH_SIZE = 128 +NUM_EPOCHS = 200 + +fp16=dict( + mode=AMP_TYPE.TORCH +) +``` + +在这个配置文件中,我们指定要在每个 GPU 上使用批大小为128,并运行200个 epoch。这两个参数是在 `gpc.config` 中体现的。例如,您可以使用 `gpc.config.BATCH_SIZE` 来访问您存储在配置文件中的批大小值。而 `fp16` 配置则会告诉 `colossalai.initialize` 使用 PyTorch 提供的混合精度训练,以更好的速度和更低的内存消耗来训练模型。 + +#### 步骤 2. 初始化分布式环境 + +我们需要初始化分布式训练环境。这在 [启动 Colossal-AI](./launch_colossalai.md) 中有相应的教程。在当前的演示中,我们使用 `launch_from_torch` 和 PyTorch 启用工具。 + +```python +import colossalai + +# ./config.py refers to the config file we just created in step 1 +colossalai.launch_from_torch(config='./config.py') +``` + +#### 步骤 3. 创建所有的训练组件 + +这时,我们可以创建用于训练的所有组件,包括: + +1. 模型 +2. 优化器 +3. 损失函数 +4. 训练/测试数据加载器 +5. 学习率调度器 +6. 日志记录器 + + + +为了构建这些组件,您需要导入以下模块。 + +```python +from pathlib import Path +from colossalai.logging import get_dist_logger +import torch +import os +from colossalai.core import global_context as gpc +from colossalai.utils import get_dataloader +from torchvision import transforms +from colossalai.nn.lr_scheduler import CosineAnnealingLR +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet34 +``` + + + +然后按照通常在PyTorch脚本中构建组件的方式来构建组件。在下面的脚本中,我们将CIFAR10数据集的根路径设置为环境变量 `DATA`。您可以把它改为您想要的任何路径,例如,您可以把 `root=Path(os.environ['DATA'])` 改为 `root='./data'` ,这样就不需要设置环境变量。 + +```python +# build logger +logger = get_dist_logger() + +# build resnet +model = resnet34(num_classes=10) + +# build datasets +train_dataset = CIFAR10( + root='./data', + download=True, + transform=transforms.Compose( + [ + transforms.RandomCrop(size=32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[ + 0.2023, 0.1994, 0.2010]), + ] + ) +) + +test_dataset = CIFAR10( + root='./data', + train=False, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[ + 0.2023, 0.1994, 0.2010]), + ] + ) +) + +# build dataloaders +train_dataloader = get_dataloader(dataset=train_dataset, + shuffle=True, + batch_size=gpc.config.BATCH_SIZE, + num_workers=1, + pin_memory=True, + ) + +test_dataloader = get_dataloader(dataset=test_dataset, + add_sampler=False, + batch_size=gpc.config.BATCH_SIZE, + num_workers=1, + pin_memory=True, + ) + +# build criterion +criterion = torch.nn.CrossEntropyLoss() + +# optimizer +optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) + +# lr_scheduler +lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS) +``` + +#### 步骤 4. 用 Colossal-AI 进行初始化 + +接下来,重要的一步是通过调用 `colossalai.initialize` 获得 Engine。正如 `config.py` 中所述,我们将使用混合精度训练来训练 ResNet34 模型。`colossalai.initialize` 将自动检查您的配置文件,并将相关特征分配给您的训练组件。这样一来,我们的 Engine 已经能够进行混合精度训练,而您不需要进行额外的处理。 + +```python +engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, + optimizer, + criterion, + train_dataloader, + test_dataloader, + ) +``` + + + +#### 步骤 5. 用 Engine 进行训练 + +当所有的训练组件都准备好后,我们就可以像使用 PyTorch 一样训练 ResNet34 了。 + +```python +for epoch in range(gpc.config.NUM_EPOCHS): + # execute a training iteration + engine.train() + for img, label in train_dataloader: + img = img.cuda() + label = label.cuda() + + # set gradients to zero + engine.zero_grad() + + # run forward pass + output = engine(img) + + # compute loss value and run backward pass + train_loss = engine.criterion(output, label) + engine.backward(train_loss) + + # update parameters + engine.step() + + # update learning rate + lr_scheduler.step() + + # execute a testing iteration + engine.eval() + correct = 0 + total = 0 + for img, label in test_dataloader: + img = img.cuda() + label = label.cuda() + + # run prediction without back-propagation + with torch.no_grad(): + output = engine(img) + test_loss = engine.criterion(output, label) + + # compute the number of correct prediction + pred = torch.argmax(output, dim=-1) + correct += torch.sum(pred == label) + total += img.size(0) + + logger.info( + f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", ranks=[0]) +``` + +#### 步骤 6. 用 Trainer 进行训练 + +如果您想用 Trainer 进行训练,您可以参考下面的代码进行您的实验。 + + +```python +from colossalai.nn.metric import Accuracy +from colossalai.trainer import Trainer, hooks + + +# create a trainer object +trainer = Trainer( + engine=engine, + logger=logger +) + +# define the hooks to attach to the trainer +hook_list = [ + hooks.LossHook(), + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True), + hooks.AccuracyHook(accuracy_func=Accuracy()), + hooks.LogMetricByEpochHook(logger), + hooks.LogMemoryByEpochHook(logger) +] + +# start training +# run testing every 1 epoch +trainer.fit( + train_dataloader=train_dataloader, + epochs=gpc.config.NUM_EPOCHS, + test_dataloader=test_dataloader, + test_interval=1, + hooks=hook_list, + display_progress=True +) +``` + + + +#### 步骤 7. 开始分布式训练 + +最后,我们可以使用 PyTorch 提供的分布式启动器来调用脚本,因为我们在步骤2中使用了 `launch_from_torch`。您需要把`` 替换成您机器上可用的GPU数量。如果您只想使用一个 GPU,您可以把这个数字设为1。如果您想使用其他的启动器,请您参考如何启动 Colossal-AI 的教程。 + + +```bash +# with engine +python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py +# with trainer +python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py +``` diff --git a/docs/source/zh-Hans/basics/initialize_features.md b/docs/source/zh-Hans/basics/initialize_features.md new file mode 100644 index 0000000000000000000000000000000000000000..67ea114b42b29e545c628d4cc9ae0fd77c1da2f8 --- /dev/null +++ b/docs/source/zh-Hans/basics/initialize_features.md @@ -0,0 +1,46 @@ +# 初始化功能 + +作者: Shenggui Li, Siqi Mai + +**预备知识:** +- [分布式训练](../concepts/distributed_training.md) +- [Colossal-AI 总览](../concepts/colossalai_overview.md) + +## 简介 + +在本教程中,我们将介绍 `colossalai.initialize` 的使用。 它包含了如何将特征(例如,模型、优化器、数据加载器)无缝注入您的训练组件中。 调用 `colossalai.initialize` 是您进入训练循环前的基本操作。 + +在下面一节中,我们将介绍 `colossalai.initialize` 是如何工作的以及使用中我们要注意的细节。 + +## 使用 + +在一个典型的工作流程中,我们将在训练脚本的开始启动分布式环境。 +之后,我们将实例化我们的对象,如模型、优化器、损失函数、数据加载器等。此时,我们可以使用 `colossalai.initialize` 便捷地为这些对象注入特征。 +具体细节请看以下的伪代码例子。 + +```python +import colossalai +import torch +... + + +# launch distributed environment +colossalai.launch(config='./config.py', ...) + +# create your objects +model = MyModel() +optimizer = torch.optim.Adam(model.parameters(), lr=0.001) +criterion = torch.nn.CrossEntropyLoss() +train_dataloader = MyTrainDataloader() +test_dataloader = MyTrainDataloader() + +# initialize features +engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, + optimizer, + criterion, + train_dataloader, + test_dataloader) +``` + + `colossalai.initialize` 将返回一个 `Engine` 对象。 该对象把模型、优化器和损失函数封装起来。 **`Engine` 对象会以配置文件中指定的特征运行。** +关于 `Engine` 的更多使用细节可以在 [在训练中使用Engine和Trainer](./engine_trainer.md) 中获取。 diff --git a/docs/source/zh-Hans/basics/launch_colossalai.md b/docs/source/zh-Hans/basics/launch_colossalai.md new file mode 100644 index 0000000000000000000000000000000000000000..ca927de578d5d7db96f6d1717f8aed892da4504a --- /dev/null +++ b/docs/source/zh-Hans/basics/launch_colossalai.md @@ -0,0 +1,212 @@ +# 启动 Colossal-AI + +作者: Chuanrui Wang, Shenggui Li, Siqi Mai + +**预备知识:** +- [分布式训练](../concepts/distributed_training.md) +- [Colossal-AI 总览](../concepts/colossalai_overview.md) + + +## 简介 + +正如我们在前面的教程中所提到的,在您的配置文件准备好后,您需要为 Colossal-AI 初始化分布式环境。我们把这个过程称为 `launch`。在本教程中,您将学习如何在您的服务器上启动 Colossal-AI,不管是小型的还是大型的。 + +在 Colossal-AI 中,我们提供了几种启动方法来初始化分布式后端。 +在大多数情况下,您可以使用 `colossalai.launch` 和 `colossalai.get_default_parser` 来通过命令行传递参数。如果您想使用 SLURM、OpenMPI 和 PyTorch 等启动工具,我们也提供了几个启动的辅助方法以便您的使用。您可以直接从这些启动工具设置的环境变量中访问 rank 和 world size 大小。 + +在本教程中,我们将介绍如何启动 Colossal-AI 来初始化分布式后端: +- 用 colossalai.launch 启动 +- 用 Colossal-AI命令行 启动 +- 用 SLURM 启动 +- 用 OpenMPI 启动 + +## 启动分布式环境 + +为了启动 Colossal-AI,我们需要两类参数: +1. 配置文件 +2. 分布式设置 + +无论我们使用何种启动方式,配置文件是必须要求的,而分布式设置有可能依情况而定。配置文件可以是配置文件的路径或 Python dictionary 的形式。分布式设置可以通过命令行或多进程启动器传递。 + +### 命令行解析器 + +在使用 `launch` 之前, 我们首先需要了解我们需要哪些参数来进行初始化。 +如[分布式训练](../concepts/distributed_training.md) 中 `基本概念` 一节所述 ,涉及的重要参数是: + +1. host +2. port +3. rank +4. world_size +5. backend + +在 Colossal-AI 中,我们提供了一个命令行解析器,它已经提前添加了这些参数。您可以通过调用 `colossalai.get_default_parser()` 来获得这个解析器。这个解析器通常与 `colossalai.launch` 一起使用。 + +```python +# add these lines in your train.py +import colossalai + +# get default parser +parser = colossalai.get_default_parser() + +# if you want to add your own arguments +parser.add_argument(...) + +# parse arguments +args = parser.parse_args() +``` + +您可以在您的终端传入以下这些参数。 +```shell + +python train.py --host --rank --world_size --port --backend +``` + +`backend` 是用户可选的,默认值是 nccl。 + +### 本地启动 + +为了初始化分布式环境,我们提供了一个通用的 `colossalai.launch` API。`colossalai.launch` 函数接收上面列出的参数,并在通信网络中创建一个默认的进程组。方便起见,这个函数通常与默认解析器一起使用。 + +```python +import colossalai + +# parse arguments +args = colossalai.get_default_parser().parse_args() + +# launch distributed environment +colossalai.launch(config=, + rank=args.rank, + world_size=args.world_size, + host=args.host, + port=args.port, + backend=args.backend +) + +``` + + +### 用 Colossal-AI命令行工具 启动 + +为了更好地支持单节点以及多节点的训练,我们通过封装PyTorch的启动器实现了一个更加方便的启动器。 +PyTorch自带的启动器需要在每个节点上都启动命令才能启动多节点训练,而我们的启动器只需要一次调用即可启动训练。 + +首先,我们需要在代码里指定我们的启动方式。由于这个启动器是PyTorch启动器的封装,那么我们自然而然应该使用`colossalai.launch_from_torch`。 +分布式环境所需的参数,如 rank, world size, host 和 port 都是由 PyTorch 启动器设置的,可以直接从环境变量中读取。 + +```python +import colossalai + +colossalai.launch_from_torch( + config=, +) +``` + +接下来,我们可以轻松地在终端使用`colossalai run`来启动训练。下面的命令可以在当前机器上启动一个4卡的训练任务。 +你可以通过设置`nproc_per_node`来调整使用的GPU的数量,也可以改变`master_port`的参数来选择通信的端口。 + +```shell +# 在当前节点上启动4卡训练 (默认使用29500端口) +colossalai run --nproc_per_node 4 train.py + +# 在当前节点上启动4卡训练,并使用一个不同的端口 +colossalai run --nproc_per_node 4 --master_port 29505 test.py +``` + +如果你在使用一个集群,并且想进行多节点的训练,你需要使用Colossal-AI的命令行工具进行一键启动。我们提供了两种方式来启动多节点任务 + +- 通过`--hosts`来启动 + +这个方式适合节点数不多的情况。假设我们有两个节点,分别为`host`和`host2`。我们可以用以下命令进行多节点训练。 +比起单节点训练,多节点训练需要手动设置`--master_addr` (在单节点训练中`master_addr`默认为`127.0.0.1`)。 + +:::caution + +多节点训练时,`master_addr`不能为`localhost`或者`127.0.0.1`,它应该是一个节点的名字或者IP地址。 + +::: + +```shell +# 在两个节点上训练 +colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py +``` + + +- 通过`--hostfile`来启动 + +这个方式适用于节点数很大的情况。host file是一个简单的文本文件,这个文件里列出了可以使用的节点的名字。 +在一个集群中,可用节点的列表一般由SLURM或者PBS Pro这样的集群资源管理器来提供。比如,在SLURM中, +你可以从`SLURM_NODELIST`这个环境变量中获取到当前分配列表。在PBS Pro中,这个环境变量为`PBS_NODEFILE`。 +可以通过`echo $SLURM_NODELIST` 或者 `cat $PBS_NODEFILE` 来尝试一下。如果你没有这样的集群管理器, +那么你可以自己手动写一个这样的文本文件即可。 + +提供给Colossal-AI的host file需要遵循以下格式,每一行都是一个节点的名字。 + +```text +host1 +host2 +``` + +如果host file准备好了,那么我们就可以用以下命令开始多节点训练了。和使用`--host`一样,你也需要指定一个`master_addr`。 +当使用host file时,我们可以使用一些额外的参数: +- `--include`: 设置你想要启动训练的节点。比如,你的host file里有8个节点,但是你只想用其中的6个节点进行训练, + 你可以添加`--include host1,host2,host3,...,host6`,这样训练任务只会在这6个节点上启动。 + +- `--exclude`: 设置你想排除在训练之外的节点。当你的某一些节点坏掉时,这个参数会比较有用。比如假如host1的GPU有一些问题,无法正常使用, + 那么你就可以使用`--exclude host1`来将其排除在外,这样你就可以训练任务就只会在剩余的节点上启动。 + +```shell +# 使用hostfile启动 +colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 test.py + +# 只使用部分节点进行训练 +colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 --include host1 test.py + +# 不使用某些节点进行训练 +colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 --exclude host2 test.py +``` + + +### 用 SLURM 启动 + +如果您是在一个由 SLURM 调度器管理的系统上, 您也可以使用 `srun` 启动器来启动您的 Colossal-AI 脚本。我们提供了辅助函数 `launch_from_slurm` 来与 SLURM 调度器兼容。 +`launch_from_slurm` 会自动从环境变量 `SLURM_PROCID` 和 `SLURM_NPROCS` 中分别读取 rank 和 world size ,并使用它们来启动分布式后端。 + +您可以在您的训练脚本中尝试以下操作。 + +```python +import colossalai + +colossalai.launch_from_slurm( + config=, + host=args.host, + port=args.port +) +``` + +您可以通过在终端使用这个命令来初始化分布式环境。 + +```bash +srun python train.py --host --port 29500 +``` + +### 用 OpenMPI 启动 +如果您对OpenMPI比较熟悉,您也可以使用 `launch_from_openmpi` 。 +`launch_from_openmpi` 会自动从环境变量 +`OMPI_COMM_WORLD_LOCAL_RANK`, `MPI_COMM_WORLD_RANK` 和 `OMPI_COMM_WORLD_SIZE` 中分别读取local rank、global rank 和 world size,并利用它们来启动分布式后端。 + +您可以在您的训练脚本中尝试以下操作。 +```python +colossalai.launch_from_openmpi( + config=, + host=args.host, + port=args.port +) +``` + +以下是用 OpenMPI 启动多个进程的示例命令。 +```bash +mpirun --hostfile -np python train.py --host --port 29500 +``` + +- --hostfile: 指定一个要运行的主机列表。 +- --np: 设置总共要启动的进程(GPU)的数量。例如,如果 --np 4,4个 python 进程将被初始化以运行 train.py。 diff --git a/docs/source/zh-Hans/basics/model_checkpoint.md b/docs/source/zh-Hans/basics/model_checkpoint.md new file mode 100644 index 0000000000000000000000000000000000000000..cec12d45198911c526326590bd15d2617a6b4cf4 --- /dev/null +++ b/docs/source/zh-Hans/basics/model_checkpoint.md @@ -0,0 +1,61 @@ +# 模型检查点 + +作者 : Guangyang Lu + +**预备知识:** +- [Launch Colossal-AI](./launch_colossalai.md) +- [Initialize Colossal-AI](./initialize_features.md) + +**示例代码:** +- [ColossalAI-Examples Model Checkpoint](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/utils/checkpoint) + +**函数是经验函数.** + +## 简介 + +本教程将介绍如何保存和加载模型检查点。 + +为了充分利用Colossal-AI的强大并行策略,我们需要修改模型和张量,可以直接使用 `torch.save` 或者 `torch.load` 保存或加载模型检查点。在Colossal-AI中,我们提供了应用程序接口实现上述同样的效果。 + +但是,在加载时,你不需要使用与存储相同的保存策略。 + +## 使用方法 + +### 保存 + +有两种方法可以使用Colossal-AI训练模型,即使用engine或使用trainer。 +**注意我们只保存 `state_dict`.** 因此,在加载检查点时,需要首先定义模型。 + +#### 同 engine 保存 + +```python +from colossalai.utils import save_checkpoint +model = ... +engine, _, _, _ = colossalai.initialize(model=model, ...) +for epoch in range(num_epochs): + ... # do some training + save_checkpoint('xxx.pt', epoch, model) +``` + +#### 用 trainer 保存 +```python +from colossalai.trainer import Trainer, hooks +model = ... +engine, _, _, _ = colossalai.initialize(model=model, ...) +trainer = Trainer(engine, ...) +hook_list = [ + hooks.SaveCheckpointHook(1, 'xxx.pt', model) + ...] + +trainer.fit(... + hook=hook_list) +``` + +### 加载 + +```python +from colossalai.utils import load_checkpoint +model = ... +load_checkpoint('xxx.pt', model) +... # train or test +``` diff --git a/docs/source/zh-Hans/concepts/colossalai_overview.md b/docs/source/zh-Hans/concepts/colossalai_overview.md new file mode 100755 index 0000000000000000000000000000000000000000..cfb35e59e64a99924297e5bbdd83c96930a2f93c --- /dev/null +++ b/docs/source/zh-Hans/concepts/colossalai_overview.md @@ -0,0 +1,36 @@ +# Colossal-AI 总览 + +作者: Shenggui Li, Siqi Mai + +## 关于 Colossal-AI + +随着深度学习模型规模的发展,向新的训练模式转变是非常重要的。没有并行和优化的传统训练方法将成为过去,新的训练方法是使训练大规模模型高效和节省成本的关键。 + +Colossal-AI 是一个集成的系统,为用户提供一套综合的训练方法。您可以找到常见的训练方法,如混合精度训练和梯度累积。此外,我们提供了一系列的并行技术,包括数据并行、张量并行和流水线并行。我们通过不同的多维分布式矩阵乘法算法来优化张量并行。我们还提供了不同的流水线并行方法,使用户能够有效地跨节点扩展他们的模型。更多的高级功能,如卸载,也可以在这个教程文档中找到详细的内容。 + +## Colossal-AI 的使用 + +我们的目标是使 Colossal-AI 易于使用,并且对用户的代码不产生干扰。如果您想使用Colossal-AI,这里有一个简单的一般工作流程。 + +
+ +
Workflow
+
+ +1. 准备一个配置文件,指定您要使用的功能和参数。 +2. 用 `colossalai.launch` 初始化分布式后端。 +3. 用 `colossalai.initialize` 将训练特征注入您的训练组件(如模型、优化器)中。 +4. 进行训练和测试. + +我们将在`基本教程`部分介绍整个工作流程。 + +## 未来计划 + +Colossal-AI 系统将会进一步拓展和优化,包括但不限于: + +1. 分布式操作的优化 +2. 异构系统训练的优化 +3. 从模型大小的维度切入,提升训练速度并维持精度 +4. 拓展现有的并行方法 + +**我们始终欢迎社区的建议和讨论,如果您遇到任何问题,我们将非常愿意帮助您。您可以在GitHub 提 [issue](https://github.com/hpcaitech/ColossalAI/issues) ,或在[论坛](https://github.com/hpcaitech/ColossalAI/discussions)上创建一个讨论主题。** diff --git a/docs/source/zh-Hans/concepts/distributed_training.md b/docs/source/zh-Hans/concepts/distributed_training.md new file mode 100755 index 0000000000000000000000000000000000000000..97b3844daa163d92c480e37fd7b80140854aa3a1 --- /dev/null +++ b/docs/source/zh-Hans/concepts/distributed_training.md @@ -0,0 +1,88 @@ +# 分布式训练 + +作者: Shenggui Li, Siqi Mai + +## 什么是分布式系统? + +
+ +
图片来源: Towards Data Science
+
+ +分布式系统由多个软件组件组成,在多台机器上运行。例如,传统的数据库运行在一台机器上。随着数据量的爆发式增长,单台机器已经不能为企业提供理想的性能。特别是在双十一这样的网络狂欢节,网络流量会出乎意料的大。为了应对这种压力,现代高性能数据库被设计成在多台机器上运行,它们共同为用户提供高吞吐量和低延迟。 + +分布式系统的一个重要评价指标是可扩展性。例如,当我们在4台机器上运行一个应用程序时,我们自然希望该应用程序的运行速度能提高4倍。然而,由于通信开销和硬件性能的差异,很难实现线性提速。因此,当我们实现应用程序时,必须考虑如何使其更快。良好的设计和系统优化的算法可以帮助我们提供良好的性能。有时,甚至有可能实现线性和超线性提速。 + + +## 为什么我们需要机器学习的分布式训练? + +早在2012年,[AlexNet](https://arxiv.org/abs/1404.5997) 就赢得了ImageNet比赛的冠军,而它是在两张 GTX 580 3GB GPU 上训练的。今天,大多数出现在顶级人工智能会议上的模型都是在多个GPU上训练的。当研究人员和工程师开发人工智能模型时,分布式训练无疑是一种常见的做法。这一趋势背后有几个原因。 + +1. 模型规模迅速增加。2015年的 [ResNet50](https://arxiv.org/abs/1512.03385) 有2000万的参数, +2018年的 [BERT-Large](https://arxiv.org/abs/1810.04805)有3.45亿的参数,2018年的 +[GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) +有15亿的参数,而2020年的 [GPT-3](https://arxiv.org/abs/2005.14165) 有1750亿个参数。很明显,模型规模随着时间的推移呈指数级增长。目前最大的模型已经超过了1000多亿个参数。而与较小的模型相比,超大型模型通常能提供更优越的性能。 +
+ +
图片来源: HuggingFace
+
+ + +2. 数据集规模迅速增加。对于大多数机器学习开发者来说,MNIST 和 CIFAR10 数据集往往是他们训练模型的前几个数据集。然而,与著名的 ImageNet 数据集相比,这些数据集非常小。谷歌甚至有自己的(未公布的)JFT-300M 数据集,它有大约3亿张图片,这比 ImageNet-1k 数据集大了近300倍。 + + +3. 计算能力越来越强。随着半导体行业的进步,显卡变得越来越强大。由于核的数量增多,GPU是深度学习最常见的算力资源。从2012年的 K10 GPU 到2020年的 A100 GPU,计算能力已经增加了几百倍。这使我们能够更快地执行计算密集型任务,而深度学习正是这样一项任务。 + +如今,我们接触到的模型可能太大,以致于无法装入一个GPU,而数据集也可能大到足以在一个GPU上训练一百天。这时,只有用不同的并行化技术在多个GPU上训练我们的模型,我们才能完成并加快模型训练,以追求在合理的时间内获得想要的结果。 + + +## 分布式训练的基本概念 + +分布式训练需要多台机器/GPU。在训练期间,这些设备之间会有通信。为了更好地理解分布式训练,有几个重要的术语需要我们了解清楚。 + +- host: 主机(host)是通信网络中的主要设备。在初始化分布式环境时,经常需要它作为一个参数。 +- port: 这里的端口(port)主要是指主机上用于通信的主端口。 +- rank: 在网络中赋予设备的唯一ID。 +- world size: 网络中设备的数量。 +- process group: 进程组(process group)是一个通信网络,包括设备的一个子集。总是有一个默认的进程组,它包含所有的设备。一个子集的设备可以形成一个进程组,以便它们只在组内的设备之间进行通信。 + +
+ +
一个分布式系统的例子
+
+ +为了说明这些概念,让我们假设我们有2台机器(也称为节点),每台机器有4个 GPU。当我们在这两台机器上初始化分布式环境时,我们基本上启动了8个进程(每台机器上有4个进程),每个进程被绑定到一个 GPU 上。 + +在初始化分布式环境之前,我们需要指定主机(主地址)和端口(主端口)。在这个例子中,我们可以让主机为节点0,端口为一个数字,如29500。所有的8个进程将寻找地址和端口并相互连接,默认的进程组将被创建。默认进程组的 world size 为8,细节如下。 + +| process ID | rank | Node index | GPU index | +| ---------- | ---- | ---------- | --------- | +| 0 | 0 | 0 | 0 | +| 1 | 1 | 0 | 1 | +| 2 | 2 | 0 | 2 | +| 3 | 3 | 0 | 3 | +| 4 | 4 | 1 | 0 | +| 5 | 5 | 1 | 1 | +| 6 | 6 | 1 | 2 | +| 7 | 7 | 1 | 3 | + + +我们还可以创建一个新的进程组。这个新的进程组可以包含任何进程的子集。例如,我们可以创建一个只包含偶数进程的组: + +| process ID | rank | Node index | GPU index | +| ---------- | ---- | ---------- | --------- | +| 0 | 0 | 0 | 0 | +| 2 | 1 | 0 | 2 | +| 4 | 2 | 1 | 0 | +| 6 | 3 | 1 | 2 | + +**请注意,rank 是相对于进程组而言的,一个进程在不同的进程组中可以有不同的 rank。最大的 rank 始终是 `world size of the process group - 1`。** + +在进程组中,各进程可以通过两种方式进行通信。 +1. peer-to-peer: 一个进程向另一个进程发送数据。 +2. collective: 一组进程一起执行分散、聚集、all-reduce、广播等操作。 + +
+ +
Collective communication, 来源: PyTorch distributed tutorial
+
diff --git a/docs/source/zh-Hans/concepts/paradigms_of_parallelism.md b/docs/source/zh-Hans/concepts/paradigms_of_parallelism.md new file mode 100755 index 0000000000000000000000000000000000000000..8f52d28ecdf47da0b088bc43fc1283d6112720c2 --- /dev/null +++ b/docs/source/zh-Hans/concepts/paradigms_of_parallelism.md @@ -0,0 +1,92 @@ +# 并行技术 + +作者: Shenggui Li, Siqi Mai + +## 简介 + +随着深度学习的发展,对并行训练的需求越来越大。这是因为模型和数据集越来越大,如果我们坚持使用单 GPU 训练,训练过程的等待将会成为一场噩梦。在本节中,我们将对现有的并行训练方法进行简要介绍。如果您想对这篇文章进行补充,欢迎在[GitHub论坛](https://github.com/hpcaitech/ColossalAI/discussions)上进行讨论。 + +## 数据并行 + +数据并行是最常见的并行形式,因为它很简单。在数据并行训练中,数据集被分割成几个碎片,每个碎片被分配到一个设备上。这相当于沿批次维度对训练过程进行并行化。每个设备将持有一个完整的模型副本,并在分配的数据集碎片上进行训练。在反向传播之后,模型的梯度将被全部减少,以便在不同设备上的模型参数能够保持同步。 + +
+ +
数据并行
+
+ +## 模型并行 + +在数据并行训练中,一个明显的特点是每个 GPU 持有整个模型权重的副本。这就带来了冗余问题。另一种并行模式是模型并行,即模型被分割并分布在一个设备阵列上。通常有两种类型的并行:张量并行和流水线并行。张量并行是在一个操作中进行并行计算,如矩阵-矩阵乘法。流水线并行是在各层之间进行并行计算。因此,从另一个角度来看,张量并行可以被看作是层内并行,流水线并行可以被看作是层间并行。 + +### 张量并行 + +张量并行训练是将一个张量沿特定维度分成 `N` 块,每个设备只持有整个张量的 `1/N`,同时不影响计算图的正确性。这需要额外的通信来确保结果的正确性。 + +以一般的矩阵乘法为例,假设我们有 `C = AB`。我们可以将B沿着列分割成 `[B0 B1 B2 ... Bn]`,每个设备持有一列。然后我们将 `A` 与每个设备上 `B` 中的每一列相乘,我们将得到 `[AB0 AB1 AB2 ... ABn]` 。此刻,每个设备仍然持有一部分的结果,例如,设备(rank=0)持有 `AB0`。为了确保结果的正确性,我们需要收集全部的结果,并沿列维串联张量。通过这种方式,我们能够将张量分布在设备上,同时确保计算流程保持正确。 + +
+ +
张量并行
+
+ +在 Colossal-AI 中,我们提供了一系列的张量并行方法,即 1D、2D、2.5D 和 3D 张量并行。我们将在`高级教程`中详细讨论它们。 + + +相关文章: +- [GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding](https://arxiv.org/abs/2006.16668) +- [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) +- [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343) +- [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500) +- [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450) + +### 流水线并行 + +流水线并行一般来说很容易理解。请您回忆一下您的计算机结构课程,这确实存在于 CPU 设计中。 + +
+ +
流水线并行
+
+ +流水线并行的核心思想是,模型按层分割成若干块,每块都交给一个设备。在前向传递过程中,每个设备将中间的激活传递给下一个阶段。在后向传递过程中,每个设备将输入张量的梯度传回给前一个流水线阶段。这允许设备同时进行计算,并增加了训练的吞吐量。流水线并行训练的一个缺点是,会有一些设备参与计算的冒泡时间,导致计算资源的浪费。 + +
+ +
Source: GPipe
+
+ +相关文章: +- [PipeDream: Fast and Efficient Pipeline Parallel DNN Training](https://arxiv.org/abs/1806.03377) +- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965) +- [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) +- [Chimera: Efficiently Training Large-Scale Neural Networks with Bidirectional Pipelines](https://arxiv.org/abs/2107.06925) + + +## 优化器相关的并行 + +另一种并行方法和优化器相关,目前这种并行最流行的方法是 `ZeRO`,即[零冗余优化器](https://arxiv.org/abs/1910.02054)。 ZeRO 在三个层面上工作,以消除内存冗余(ZeRO需要进行fp16训练)。 + +- Level 1: 优化器状态在各进程中被划分。 +- Level 2: 用于更新模型权重的32位梯度也被划分,因此每个进程只存储与其优化器状态划分相对应的梯度。 +- Level 3: 16位模型参数在各进程中被划分。 + +相关文章: +- [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054) + + +## 异构系统的并行 + +上述方法通常需要大量的 GPU 来训练一个大型模型。然而,人们常常忽略的是,与 GPU 相比,CPU 的内存要大得多。在一个典型的服务器上,CPU 可以轻松拥有几百GB的内存,而每个 GPU 通常只有16或32GB的内存。这促使人们思考为什么 CPU 内存没有被用于分布式训练。 + +最近的进展是依靠 CPU 甚至是 NVMe 磁盘来训练大型模型。主要的想法是,在不使用张量时,将其卸载回 CPU 内存或 NVMe 磁盘。通过使用异构系统架构,有可能在一台机器上容纳一个巨大的模型。 + +
+ +
异构系统
+
+ +相关文章: +- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840) +- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857) +- [PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management](https://arxiv.org/abs/2108.05818) diff --git a/docs/source/zh-Hans/features/1D_tensor_parallel.md b/docs/source/zh-Hans/features/1D_tensor_parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..2ddc27c7b50f8cef86b444d0b24a1504ebcb8772 --- /dev/null +++ b/docs/source/zh-Hans/features/1D_tensor_parallel.md @@ -0,0 +1,119 @@ +# 1D 张量并行 + +作者: Zhengda Bian, Yongbin Li + +**前置教程** +- [定义配置文件](../basics/define_your_config.md) +- [并行配置](../basics/configure_parallelization.md) + +**示例代码** +- [ColossalAI-Examples 1D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_1d.py) + +**相关论文** +- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf) + +## 引言 + +张量并行将模型参数划分到多个设备上,以减少内存负荷。 +[Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf) 介绍了一种高效的一维张量并行化实现。 + +让我们以一个线性层为例,它包括一个 GEMM $Y = XA$。 给定2个处理器,我们把列 $A$ 划分为 $[A_1 ~ A_2]$, 并在每个处理器上计算 $Y_i = XA_i$ , 然后形成 $[Y_1 ~ Y_2] = [XA_1 ~ XA_2]$. 这被称为列并行方式。 + +当第二个线性层 $Z=YB$ 跟随上述列并行层的时候, 我们把 $B$ 划分为 +```math +\left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right] +``` +这就是所谓的行并行方式.
+ +为了计算 +```math +Z = [Y_1 ~ Y_2] \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right] +``` +我们首先在每个处理器上计算 $Y_iB_i$ 然后使用一个all-reduce操作将结果汇总为 $Z=Y_1B_1+Y_2B_2$。 + +我们还需要注意,在后向计算中,列并行线性层需要聚合输入张量 $X$, 因为在每个处理器 $i$ 上,我们只有 $\dot{X_i}=\dot{Y_i}A_i^T$,因此,我们在各处理器之间进行all-reduce,得到 $\dot{X}=\dot{Y}A^T=\dot{Y_1}A_1^T+\dot{Y_2}A_2^T$。 + +## 效率 +给定 $P$ 个处理器, 我们展现理论上的计算和内存成本,以及基于环形算法的1D张量并行的前向和后向的通信成本。 + +| 计算 | 内存 (参数) | 内存 (activations) | 通信 (带宽) | 通信 (时延) | +| :-: | :-: | :-: | :-: | :-: | +| $O(1/P)$ | $O(1/P)$ | $O(1)$ | $O(2(P-1)/P)$ | $O(2(P-1))$ | + +## 使用 + +为了使模型能够实现一维张量并行, 如在2个 GPU 上, 我们需要配置如下的并行设置。 +```python +CONFIG = dict(parallel=dict( + data=1, + pipeline=1, + tensor=dict(size=2, mode='1d'), +)) +``` + +然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用1D张量并行。 + +让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。 +```python +import colossalai +import colossalai.nn as col_nn +import torch +from colossalai.utils import print_rank_0 + +class MLP(torch.nn.Module): + def __init__(self, dim: int = 256): + super().__init__() + intermediate_dim = dim * 4 + self.dense_1 = col_nn.Linear(dim, intermediate_dim) + print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.transpose(0, 1).shape}') + self.activation = torch.nn.GELU() + self.dense_2 = col_nn.Linear(intermediate_dim, dim) + print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.transpose(0, 1).shape}') + self.dropout = col_nn.Dropout(0.1) + + def forward(self, x): + x = self.dense_1(x) + print_rank_0(f'Output of the first linear layer: {x.shape}') + x = self.activation(x) + x = self.dense_2(x) + print_rank_0(f'Output of the second linear layer: {x.shape}') + x = self.dropout(x) + return x +``` + +在2个 GPU 上启动 Colossal-AI 并建立模型。 + +```python +parser = colossalai.get_default_parser() +colossalai.launch(config=CONFIG, + rank=args.rank, + world_size=args.world_size, + local_rank=args.local_rank, + host=args.host, + port=args.port) + +m = MLP() +``` +我们将会看到 MLP 模型中被划分的参数(如权重)的形状。 +```shell +Weight of the first linear layer: torch.Size([256, 512]) +Weight of the second linear layer: torch.Size([512, 256]) +``` +第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过列-并行分割,它变成了 `[256, 512]`。 +同样地,第二个行并行层将权重 `[1024, 256]` 划分为 `[512, 256]`。 + +我们可以用一些随机输入来运行这个模型。 +```python +from colossalai.utils import get_current_device + +x = torch.randn((16, 256), device=get_current_device()) +torch.distributed.broadcast(x, src=0) # synchronize input + +x = m(x) +``` +然后我们可以看到 activation 结果的形状。 +```shell +Output of the first linear layer: torch.Size([16, 512]) +Output of the second linear layer: torch.Size([16, 256]) +``` +第一个线性层的输出被划分成2块 (每个形状为 `[16, 512]`), 而第二层在整个 GPU 上的输出是相同的。 diff --git a/docs/source/zh-Hans/features/2D_tensor_parallel.md b/docs/source/zh-Hans/features/2D_tensor_parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..c942f82bf9d2592b981dc77ede304a8fe8a674cb --- /dev/null +++ b/docs/source/zh-Hans/features/2D_tensor_parallel.md @@ -0,0 +1,141 @@ +# 2D 张量并行 + +作者: Zhengda Bian, Yongbin Li + +**前置教程** +- [定义配置文件](../basics/define_your_config.md) +- [并行配置](../basics/configure_parallelization.md) +- [1D 张量并行](./1D_tensor_parallel.md) + +**示例代码** +- [ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_2d.py) + +**相关论文** +- [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/pdf/2104.05343.pdf) + +## 引言 + +1D张量并行没有对 activations 进行划分,就大规模模型而言,这也会消耗大量的内存。 +为了平均分配计算和内存负荷,在 SUMMA(可扩展的通用矩阵乘法算法)的基础上, [2D张量并行](https://arxiv.org/pdf/2104.05343.pdf) 被引入。 + +我们还是以线性层 $Y = XA$ 为例。 +给定 $P=q\times q$ 个处理器(必要条件), 如 $q=2$, 我们把输入 $X$ 和权重A $A$ 都划分为 + +$$ +\left[\begin{matrix} X_{10} & X_{11} \\ X_{00} & X_{01} \end{matrix} \right] +\text{~and~} +\left[\begin{matrix} A_{10} & A_{11} \\ A_{00} & A_{01} \end{matrix} \right]。 +$$ + +该计算包括 $q$ 步。 当 $t=1$ 时, $X_{i0}$ 在其行中被广播, 而 $A_{0j}$ 在其列中被广播。因此,我们有 + +$$ +\left[\begin{matrix} X_{10},A_{00} & X_{10},A_{01} \\ X_{00},A_{00} & X_{00},A_{01} \end{matrix} \right]。 +$$ + +然后我们在每个处理器 $(i, j)$ 上将 $X_{i0}$ 和 $A_{0j}$ 相乘为 + +$$ +\left[\begin{matrix} X_{10}A_{00} & X_{10}A_{01} \\ X_{00}A_{00} & X_{00}A_{01} \end{matrix} \right] (1)。 +$$ + +同样,当 $t=2$ 时, $X_{i1}$ 在其行中被广播, $A_{1j}$ 在其列中被广播, 我们将它们相乘为 + +$$ +\left[\begin{matrix} X_{11}A_{10} & X_{11}A_{11} \\ X_{01}A_{10} & X_{01}A_{11} \end{matrix} \right] (2)。 +$$ + +通过将 $(1)$ 和 $(2)$ 相加,我们有 + +$$ +Y = XA = \left[\begin{matrix} X_{10}A_{00}+X_{11}A_{10} & X_{10}A_{01}+X_{11}A_{11} \\ X_{00}A_{00}+X_{01}A_{10} & X_{00}A_{01}+X_{01}A_{11} \end{matrix} \right]。 +$$ + +## 效率 +给定 $P=q\times q$ 个处理器, 我们展现理论上的计算和内存成本,以及基于环形算法的2D张量并行的前向和后向的通信成本。 + +| 计算 | 内存 (参数) | 内存 (activations) | 通信 (带宽) | 通信 (时延) | +| :-: | :-: | :-: | :-: | :-: | +| $O(1/q^2)$ | $O(1/q^2)$ | $O(1/q^2)$ | $O(6(q-1)/q)$ | $O(6(q-1))$ | + +## 使用 + +为了使我们的模型能够实现二维张量并行,例如在4个 GPU 上,我们需要配置如下的并行设置。 +```python +CONFIG = dict(parallel=dict( + data=1, + pipeline=1, + tensor=dict(size=4, mode='2d'), +)) +``` +然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用2D张量并行。 + +让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。 +```python +import colossalai +import colossalai.nn as col_nn +import torch +from colossalai.utils import print_rank_0 + +class MLP(torch.nn.Module): + def __init__(self, dim: int = 256): + super().__init__() + intermediate_dim = dim * 4 + self.dense_1 = col_nn.Linear(dim, intermediate_dim) + print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') + self.activation = torch.nn.GELU() + self.dense_2 = col_nn.Linear(intermediate_dim, dim) + print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') + self.dropout = col_nn.Dropout(0.1) + + def forward(self, x): + x = self.dense_1(x) + print_rank_0(f'Output of the first linear layer: {x.shape}') + x = self.activation(x) + x = self.dense_2(x) + print_rank_0(f'Output of the second linear layer: {x.shape}') + x = self.dropout(x) + return x +``` +在4个 GPU 上启动 Colossal-AI 并建立模型。 +```python +parser = colossalai.get_default_parser() +colossalai.launch(config=CONFIG, + rank=args.rank, + world_size=args.world_size, + local_rank=args.local_rank, + host=args.host, + port=args.port) + +m = MLP() +``` +我们将会看到 MLP 模型中被划分的参数(如权重)的形状。 +```shell +Weight of the first linear layer: torch.Size([128, 512]) +Weight of the second linear layer: torch.Size([512, 128]) +``` +第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过2D并行划分后,它在每个 GPU 上变成了 `[128, 512]` 。 +同样地,第二层将权重 `[1024, 256]` 划分为 `[512, 128]`. + +我们可以用一些随机输入来运行这个模型。 +```python +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device + +x = torch.randn((16, 256), device=get_current_device()) +# partition input +torch.distributed.broadcast(x, src=0) +x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)] +x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)] +print_rank_0(f'Input: {x.shape}') + +x = m(x) +``` +然后我们可以看到 activation 结果的形状。 +```shell +Input: torch.Size([8, 128]) +Output of the first linear layer: torch.Size([8, 512]) +Output of the second linear layer: torch.Size([8, 128]) +``` +2D并行中的 activation 张量都是同时在行和列分割的。例如,第一个线性层的输出是 `[8, 512]`, 而第二层的输出为 `[8, 128]`。 diff --git a/docs/source/zh-Hans/features/2p5D_tensor_parallel.md b/docs/source/zh-Hans/features/2p5D_tensor_parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..59a4be02ce47764889fb38b6fa66503a58322baf --- /dev/null +++ b/docs/source/zh-Hans/features/2p5D_tensor_parallel.md @@ -0,0 +1,145 @@ +# 2.5D 张量并行 + +作者: Zhengda Bian, Yongbin Li + +**前置教程** +- [定义配置文件](../basics/define_your_config.md) +- [并行配置](../basics/configure_parallelization.md) +- [1D 张量并行](./1D_tensor_parallel.md) +- [2D 张量并行](./2D_tensor_parallel.md) + +**示例代码** +- [ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_2p5d.py) + +**相关论文** +- [2.5-dimensional distributed model training](https://arxiv.org/pdf/2105.14500.pdf) + +## 引言 + +与一维张量并行相比,二维并行降低了内存成本,但可能引入更多的通信。因此,[2.5D张量并行](https://arxiv.org/pdf/2105.14500.pdf) 在 2.5D SUMMA 的基础上被提出,它通过使用更多的设备来减少通信。 + +我们还是以线性层 $Y = XA$ 为例。 +给定 $P=q \times q \times d$ 个处理器(必要条件), 如 $q=d=2$, 我们把输入 $X$ 划分为 $d\times q$ 行和 $q$ 列 + +$$ +\left[\begin{matrix} X_{30} & X_{31} \\ X_{20} & X_{21} \\ X_{10} & X_{11} \\ X_{00} & X_{01}\end{matrix} \right], +$$ +它可以被重塑为 $d$ 层 + +$$ +\left[\begin{matrix} X_{10} & X_{11} \\ X_{00} & X_{01} \end{matrix} \right] \text{~and~}\left[\begin{matrix} X_{30} & X_{31} \\ X_{20} & X_{21} \end{matrix} \right]. +$$ + +另外,权重 $A$ 被分割为 + +$$ +\left[\begin{matrix} A_{10} & A_{11} \\ A_{00} & A_{01} \end{matrix} \right]. +$$ + +对于 $X$ 相关的每一层, 我们使用SUMMA算法将 $X$ 与 $A$ 相乘。 +然后,我们得到输出 + +$$ +\left[\begin{matrix} Y_{10}=X_{10}A_{00}+X_{11}A_{10} & Y_{11}=X_{10}A_{01}+X_{11}A_{11} \\ Y_{00}=X_{00}A_{00}+X_{01}A_{10} & Y_{01}=X_{00}A_{01}+X_{01}A_{11} \end{matrix} \right] +\text{~and~} +$$ +$$ +\left[\begin{matrix} Y_{30}=X_{30}A_{00}+X_{31}A_{10} & Y_{31}=X_{30}A_{01}+X_{31}A_{11} \\ Y_{20}=X_{20}A_{00}+X_{21}A_{10} & Y_{21}=X_{20}A_{01}+X_{21}A_{11} \end{matrix} \right]. +$$ + +## 效率 + +给定 $P=q \times q \times d$ 个处理器, 我们展现理论上的计算和内存成本,以及基于环形算法的2.5D张量并行的前向和后向的通信成本。 + +| 计算 | 内存 (参数) | 内存 (activations) | 通信 (带宽) | 通信 (时延) | +| :-: | :-: | :-: | :-: | :-: | +| $O(1/dq^2)$ | $O(1/q^2)$ | $O(1/dq^2)$ | $\small O(3(q-1)(d+1)/dq)$ | $O(6(q-1))$ | + +## 使用 + +为了使我们的模型能够实现2.5D张量并行,例如在8个 GPU 上,我们需要配置如下的并行设置。 + +```python +CONFIG = dict(parallel=dict( + data=1, + pipeline=1, + tensor=dict(size=8, mode='2.5d', depth=2), +)) + +``` + +然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用2.5D张量并行。 + +让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。 + +```python +import colossalai +import colossalai.nn as col_nn +import torch +from colossalai.utils import print_rank_0 + +class MLP(torch.nn.Module): + def __init__(self, dim: int = 256): + super().__init__() + intermediate_dim = dim * 4 + self.dense_1 = col_nn.Linear(dim, intermediate_dim) + print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') + self.activation = torch.nn.GELU() + self.dense_2 = col_nn.Linear(intermediate_dim, dim) + print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') + self.dropout = col_nn.Dropout(0.1) + + def forward(self, x): + x = self.dense_1(x) + print_rank_0(f'Output of the first linear layer: {x.shape}') + x = self.activation(x) + x = self.dense_2(x) + print_rank_0(f'Output of the second linear layer: {x.shape}') + x = self.dropout(x) + return x +``` +在8个 GPU 上启动 Colossal-AI 并建立模型。 +```python +parser = colossalai.get_default_parser() +colossalai.launch(config=CONFIG, + rank=args.rank, + world_size=args.world_size, + local_rank=args.local_rank, + host=args.host, + port=args.port) + +m = MLP() +``` +我们将会看到 MLP 模型中被划分的参数(如权重)的形状。 +```shell +Weight of the first linear layer: torch.Size([128, 512]) +Weight of the second linear layer: torch.Size([512, 128]) +``` + +第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过2.5D并行划分后,它在每个 GPU 上变成了 `[128, 512]` 。 +同样地,第二层将权重 `[1024, 256]` 划分为 `[512, 128]`. + +我们可以用一些随机输入来运行这个模型。 +```python +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device + +x = torch.randn((16, 256), device=get_current_device()) +# partition input +torch.distributed.broadcast(x, src=0) +x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)] +x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)] +x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)] +print_rank_0(f'Input: {x.shape}') + +x = m(x) +``` +然后我们可以看到 activation 结果的形状。 +```shell +Input: torch.Size([4, 128]) +Output of the first linear layer: torch.Size([4, 512]) +Output of the second linear layer: torch.Size([4, 128]) +``` +2.5D并行中的 activation 张量都是同时在$d \times q$行和$q$列分割的。例如,第一个线性层的输出是 `[4, 512]`, 而第二层的输出为 `[4, 128]`。 +注意,2.5D并行使用与2D并行相同的划分方法来处理权重,区别在于对输入的划分。 diff --git a/docs/source/zh-Hans/features/3D_tensor_parallel.md b/docs/source/zh-Hans/features/3D_tensor_parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..440121c942431b48869487f8199a88fb3fa9133a --- /dev/null +++ b/docs/source/zh-Hans/features/3D_tensor_parallel.md @@ -0,0 +1,154 @@ +# 3D 张量并行 + +作者: Zhengda Bian, Yongbin Li + +**前置教程** +- [定义配置文件](../basics/define_your_config.md) +- [并行配置](../basics/configure_parallelization.md) +- [1D 张量并行](./1D_tensor_parallel.md) +- [2D 张量并行](./2D_tensor_parallel.md) + +**示例代码** +- [ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_3d.py) + +**相关论文** +- [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/pdf/2105.14450.pdf) + +## 引言 + +[3D 张量并行](https://arxiv.org/pdf/2105.14450.pdf) 是一种将神经网络模型的计算并行化,以期望获得最佳通信成本优化的方法。 + +我们还是以线性层 $Y = XA$ 为例。 +给定 $P=q \times q \times q$ 个处理器(必要条件), 如 $q=2$, 我们把输入 $X$ 和权重 $A$ 划分为 + +$$ +\left[\begin{matrix} + X_{000} & X_{001} \\ + X_{010} & X_{011} \\ + X_{100} & X_{101} \\ + X_{110} & X_{111} \end{matrix} +\right] +\text{~and~} +\left[\begin{matrix} + A_{000} & A_{001} & A_{010} & A_{011} \\ + A_{100} & A_{101} & A_{110} & A_{111} \end{matrix} +\right] +\text{~respectively,}$$ +其中每个 $X_{ijl}$ 和 $A_{lji}$ 都被存储在处理器 $(i,j,l)$ 上, 如下图所示。 + +
+ + + + +
+ +然后我们在 $(i, 0...q,l)$ 上收集 $X_{ijl}$, 以及在$(0...q, j, l)$ 上收集 $A_{lji}$。 +因此,我们在每个处理器 $(i,j,l)$ 上都有 $X_{il}$ 和 $A_{lj}$ 以获得 $X_{il}A_{lj}$。 +最后,我们在 $(i, j, 0...q)$ 对结果进行 reduce-scatter 得到 $Y_{ijl}$, 形成 +$$ +Y= +\left[\begin{matrix} + Y_{000} & Y_{001} \\ + Y_{010} & Y_{011} \\ + Y_{100} & Y_{101} \\ + Y_{110} & Y_{111} \end{matrix} +\right]. +$$ + +我们还需要注意,在后向传播中, 我们需要 all-gather 梯度 $\dot{Y_{ijl}}$, 然后 reduce-scatter 梯度 $\dot{X_{il}}=\dot{Y_{ij}}A_{lj}^T$ and $\dot{A_{lj}}=X_{il}^T\dot{Y_{ij}}$。 + +## 效率 +给定 $P=q \times q \times q$ 个处理器, 我们展现理论上的计算和内存成本,以及基于环形算法的3D张量并行的前向和后向的通信成本。 + +| 计算 | 内存 (参数) | 内存 (activations) | 通信 (带宽) | 通信 (时延) | +| :-: | :-: | :-: | :-: | :-: | +| $O(1/q^3)$ | $O(1/q^3)$ | $O(1/q^3)$ | $O(6(q-1)/q^3)$ | $O(6(q-1))$ | + +## 使用 + +为了使我们的模型能够实现3D张量并行,例如在8个 GPU 上,我们需要配置如下的并行设置。 + +```python +CONFIG = dict(parallel=dict( + data=1, + pipeline=1, + tensor=dict(size=8, mode='3d'), +)) +``` +然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用3D张量并行。 + +让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。 + +```python +import colossalai +import colossalai.nn as col_nn +import torch +from colossalai.utils import print_rank_0 + +class MLP(torch.nn.Module): + def __init__(self, dim: int = 256): + super().__init__() + intermediate_dim = dim * 4 + self.dense_1 = col_nn.Linear(dim, intermediate_dim) + print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') + self.activation = torch.nn.GELU() + self.dense_2 = col_nn.Linear(intermediate_dim, dim) + print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') + self.dropout = col_nn.Dropout(0.1) + + def forward(self, x): + x = self.dense_1(x) + print_rank_0(f'Output of the first linear layer: {x.shape}') + x = self.activation(x) + x = self.dense_2(x) + print_rank_0(f'Output of the second linear layer: {x.shape}') + x = self.dropout(x) + return x +``` +在8个 GPU 上启动 Colossal-AI 并建立模型。 +```python +parser = colossalai.get_default_parser() +colossalai.launch(config=CONFIG, + rank=args.rank, + world_size=args.world_size, + local_rank=args.local_rank, + host=args.host, + port=args.port) + +m = MLP() +``` +我们将会看到 MLP 模型中被划分的参数(如权重)的形状。 +```shell +Weight of the first linear layer: torch.Size([128, 256]) +Weight of the second linear layer: torch.Size([512, 64]) +``` + +第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过3D并行划分后,它在每个 GPU 上变成了 `[128, 256]` 。 +同样地,第二层将权重 `[1024, 256]` 划分为 `[512, 64]`. + +我们可以用一些随机输入来运行这个模型。 + +```python +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import get_current_device + +x = torch.randn((16, 256), device=get_current_device()) +# partition input +torch.distributed.broadcast(x, src=0) +x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)] +x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)] +x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)] +print_rank_0(f'Input: {x.shape}') + +x = m(x) +``` +然后我们可以看到 activation 结果的形状。 +```shell +Input: torch.Size([4, 128]) +Output of the first linear layer: torch.Size([4, 512]) +Output of the second linear layer: torch.Size([4, 128]) +``` +3D并行中的 activation 张量都是同时在$q^2$行和$q$列分割的。例如,第一个线性层的输出是 `[4, 512]`, 而第二层的输出为 `[4, 128]`。 +注意,虽然这里3D并行的结果与2.5D并行的结果形状相同,但每个划分的内容是不同的。 diff --git a/docs/source/zh-Hans/features/gradient_accumulation.md b/docs/source/zh-Hans/features/gradient_accumulation.md new file mode 100644 index 0000000000000000000000000000000000000000..e21e5fcd43d897761df6a5080313edb85d6a2e34 --- /dev/null +++ b/docs/source/zh-Hans/features/gradient_accumulation.md @@ -0,0 +1,40 @@ +# 梯度累积 + +作者: Shenggui Li, Yongbin Li + +**前置教程** +- [定义配置文件](../basics/define_your_config.md) +- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) + +**示例代码** +- [ColossalAI-Examples Gradient Accumulation](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_accumulation) + +## 引言 + +梯度累积是一种常见的增大训练 batch size 的方式。 在训练大模型时,内存经常会成为瓶颈,并且 batch size 通常会很小(如2),这导致收敛性无法保证。梯度累积将多次迭代的梯度累加,并仅在达到预设迭代次数时更新参数。 + +## 使用 + +在 Colossal-AI 中使用梯度累积非常简单,仅需将下列配置添加进 config 文件。其中,整数值代表期望梯度累积的次数。 + +```python +gradient_accumulation = +``` + +## 实例 + +我们提供了一个 [运行实例](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_accumulation) +来展现梯度累积。在这个例子中,梯度累积次数被设置为4,你可以通过一下命令启动脚本 + +```shell +python -m torch.distributed.launch --nproc_per_node 1 --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py +``` + +你将会看到类似下方的文本输出。这展现了梯度虽然在前3个迭代中被计算,但直到最后一次迭代,参数才被更新。 + +```text +iteration 0, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) +iteration 1, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) +iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) +iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=) +``` diff --git a/docs/source/zh-Hans/features/gradient_clipping.md b/docs/source/zh-Hans/features/gradient_clipping.md new file mode 100644 index 0000000000000000000000000000000000000000..203f66a3fea247742823e5c24f5940d2bb5bf87a --- /dev/null +++ b/docs/source/zh-Hans/features/gradient_clipping.md @@ -0,0 +1,51 @@ +# 梯度裁剪 + +作者: Boxiang Wang, Haichen Huang, Yongbin Li + +**前置教程** +- [定义配置文件](../basics/define_your_config.md) +- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) + +**示例代码** +- [ColossalAI-Examples Gradient Clipping](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_clipping) + +**相关论文** +- [On the difficulty of training Recurrent Neural Networks](https://arxiv.org/abs/1211.5063) + +## 引言 + +为了加快训练过程和寻求全局最优以获得更好的性能,越来越多的学习率调度器被提出。人们通过控制学习率来调整训练中的下降速度。这使得梯度向量在每一步都能更好地统一。在这种情况下,下降速度可以按预期被控制。 +因此,梯度裁剪,一种可以将梯度向量归一化,以将其限制在统一长度的技术,对于那些希望模型性能更好的人来说是不可或缺的。 + +在使用 Colossal-AI 时,你不必担心实现梯度剪裁,我们以一种有效而方便的方式支持梯度剪裁。你所需要的只是在你的配置文件中增加一个命令。 + +## 为什么应该使用 Colossal-AI 中的梯度裁剪 + +我们不建议用户自己编写梯度剪裁,因为朴素的梯度剪裁在应用张量并行、流水线并行、MoE 等功能时可能会失败。 + +根据下图,每个 GPU 只拥有线性层中权重的一部分参数。为了得到线性层权重的梯度向量的正确范数,每个 GPU 中的每个梯度向量的范数应该相加。更复杂的是,偏置的分布不同于权重的分布。通信组在求和运算中有所不同。 + +(注: 这种情况是旧版本的 2D 并行,在代码中的实现是不一样的。但这是一个很好的例子,能够说明在梯度剪裁中统一所有通信的困难。) + +
+ +
参数分布
+
+ +不用担心它,因为 Colossal-AI 已经为你处理好。 + +### 使用 +要使用梯度裁剪,只需在配置文件中添加梯度裁剪范数即可。 + +```python +clip_grad_norm = 1.0 +``` + +### 实例 + +我们提供了一个展现梯度裁剪的[运行实例](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_clipping) +。在本例中,我们将梯度裁剪范数设置为1.0,你可以使用以下命令运行脚本: + +```shell +python -m torch.distributed.launch --nproc_per_node 1 --master_addr localhost --master_port 29500 train_with_engine.py +``` diff --git a/docs/source/zh-Hans/features/gradient_handler.md b/docs/source/zh-Hans/features/gradient_handler.md new file mode 100644 index 0000000000000000000000000000000000000000..701c60fed57f01edfa9a49c9b68287ff9c62a7f9 --- /dev/null +++ b/docs/source/zh-Hans/features/gradient_handler.md @@ -0,0 +1,59 @@ +# 梯度 Handler + +作者: Shenggui Li, Yongbin Li + +**前置教程** +- [定义配置文件](../basics/define_your_config.md) +- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) + +**示例代码** +- [ColossalAI-Examples Gradient Handler](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_handler) + +## 引言 + +在分布式训练中,每次迭代结束时都需要梯度同步。这很重要,因为我们需要确保在不同的机器中使用相同的梯度更新参数,以便生成的参数都一样。这通常在数据并行中看到,因为在数据并行中的模型是直接复制的。 + +在 Colossal-AI 中,我们为用户提供了一个接口来定制他们想要如何处理同步。这为实现新的并行方法等情况带来了灵活性。 + +当梯度 Handler 被使用时, PyTorch 的 `DistributedDataParallel` 将不再被使用,因为它会自动同步梯度. + +## 定制你的梯度 Handler + +要实现定制的梯度Handler,需要遵循以下步骤。 +1. 继承Colossal-AI中的 `BaseGradientHandler` +2. 将梯度Handler注册进 `GRADIENT_HANDLER` +3. 实现 `handle_gradient` + +```python +from colossalai.registry import GRADIENT_HANDLER +from colossalai.engine.gradient_handler import BaseGradientHandler + + +@GRADIENT_HANDLER.register_module +class MyGradientHandler(BaseGradientHandler): + + def handle_gradient(self): + do_something() + + +``` + + +## 使用 + +要使用梯度 Handler,需要在配置文件中指定梯度 Handler。梯度 Handler 将自动构建并连接到 Engine。 + +```python +gradient_handler = [dict(type='MyGradientHandler')] +``` + + +### 实例 + +我们提供了一个 [运行实例](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_handler) +展现梯度 Handler 的使用. 在这个例子中,我们使用 `DataParallelGradientHandler` 而不是 PyTorch 的 +`DistributedDataParallel` 实现数据并行. + +```shell +python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py +``` diff --git a/docs/source/zh-Hans/features/mixed_precision_training.md b/docs/source/zh-Hans/features/mixed_precision_training.md new file mode 100644 index 0000000000000000000000000000000000000000..c9db3a59c1c3912f256e90bbc57ce5322341e3a9 --- /dev/null +++ b/docs/source/zh-Hans/features/mixed_precision_training.md @@ -0,0 +1,344 @@ +# 自动混合精度训练 (AMP) + +作者: Chuanrui Wang, Shenggui Li, Yongbin Li + +**前置教程** +- [定义配置文件](../basics/define_your_config.md) +- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) + +**示例代码** +- [ColossalAI-Examples AMP](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/amp) + +**相关论文** +- [Accelerating Scientific Computations with Mixed Precision Algorithms](https://arxiv.org/abs/0808.2794) + + +## 引言 + +AMP 代表自动混合精度训练。 +在 Colossal-AI 中, 我们结合了混合精度训练的不同实现: + +1. torch.cuda.amp +2. apex.amp +3. naive amp + + +| Colossal-AI | 支持张量并行 | 支持流水并行 | fp16范围 | +| ----------- | ----------------------- | ------------------------- | ----------- | +| AMP_TYPE.TORCH | ✅ | ❌ | 在前向和反向传播期间,模型参数、激活和梯度向下转换至fp16 | +| AMP_TYPE.APEX | ❌ | ❌ | 更细粒度,我们可以选择 opt_level O0, O1, O2, O3 | +| AMP_TYPE.NAIVE | ✅ | ✅ | 模型参数、前向和反向操作,全都向下转换至fp16 | + +前两个依赖于 PyTorch (1.6及以上) 和 NVIDIA Apex 的原始实现。最后一种方法类似 Apex O2。在这些方法中,Apex-AMP 与张量并行不兼容。这是因为张量是以张量并行的方式在设备之间拆分的,因此,需要在不同的进程之间进行通信,以检查整个模型权重中是否出现inf或nan。我们修改了torch amp实现,使其现在与张量并行兼容。 + +> ❌️ fp16与ZeRO配置不兼容 +> +> ⚠️ 流水并行目前仅支持naive amp + +我们建议使用 torch AMP,因为在不使用流水并行时,它通常比 NVIDIA AMP 提供更好的准确性。 + +## 目录 + +在本教程中,我们将介绍: + +1. AMP 介绍 +2. Colossal-AI 中的 AMP +3. 练习实例 + +## AMP 介绍 + +自动混合精度训练是混合 FP16 和 FP32 训练。 + +半精度浮点格式(FP16)具有较低的算法复杂度和较高的计算效率。此外,FP16 仅需要 FP32 所需的一半存储空间,并节省了内存和网络带宽,从而为大 batch size 和大模型提供了更多内存。 + +然而,还有其他操作,如缩减,需要 FP32 的动态范围,以避免数值溢出/下溢。因此,我们引入自动混合精度,尝试将每个操作与其相应的数据类型相匹配,这可以减少内存占用并提高训练效率。 + +
+ +
AMP 示意图 (图片来自 PatrickStar 论文)
+
+ +## Colossal-AI 中的 AMP + +我们支持三种 AMP 训练方法,并允许用户在没有改变代码的情况下使用 AMP 进行训练。只需在配置文件中添加'fp16'配置即可使用 AMP。 + +```python +from colossalai.amp import AMP_TYPE + +# 使用 Torch AMP +fp16=dict( + mode = AMP_TYPE.TORCH +) + +# 使用 naive AMP +fp16=dict( + mode = AMP_TYPE.NAIVE +) + +# 使用 Nvidia Apex AMP +fp16=dict( + mode = AMP_TYPE.APEX +) + +``` + +> 这些是最低配置,完整配置将在后面的部分中说明 + +### AMP 模块化 + +AMP 模块设计为完全模块化,可以独立使用。如果你想在你的代码库中只使用 AMP 而不使用`colossalai.initialize`,你可以导入`colossalai.amp.convert_to_amp`。 + +```python +from colossalai.amp import AMP_TYPE + +# 使用torch amp的例子 +model, optimizer, criterion = colossalai.amp.convert_to_amp(model, + optimizer, + criterion, + AMP_TYPE.TORCH) +``` + +### Torch AMP 配置 + +```python +from colossalai.amp import AMP_TYPE + +fp16=dict( + mode=AMP_TYPE.TORCH, + + # 下列是grad scaler的默认值 + init_scale=2.**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + enabled=True +) +``` + +可选参数: +- init_scale(float, optional, default=2.**16): 初始缩放因子; +- growth_factor(float, optional, default=2.0): 如果在``growth_interval``连续迭代过程中没有出现 inf/NaN 梯度,则在`update`中乘以比例系数; +- backoff_factor(float, optional, default=0.5): 如果在迭代中出现 inf/NaN 梯度,则在`update`中乘以比例系数; +- growth_interval(int, optional, default=2000): 在指定次数的连续迭代中,若没有出现 inf/NaN 梯度,则乘以``growth_factor``. +- enabled(bool, optional, default=True): ``False``则使梯度缩放无效,`step` 仅调用底层的 ``optimizer.step()``, 其他方法成为空操作。 + +### Apex AMP 配置 + +对于这种模式,我们依靠 Apex 实现混合精度训练。我们支持这个插件,因为它允许对混合精度的粒度进行更精细的控制。 +例如, O2 水平 (优化器水平2) 将保持 batch normalization 为 FP32。 + +如果你想了解更多细节,请参考 [Apex Documentation](https://nvidia.github.io/apex/)。 + +```python +from colossalai.amp import AMP_TYPE + +fp16 = dict( + mode=AMP_TYPE.APEX, + + # 下列是默认值 + enabled=True, + opt_level='O1', + cast_model_type=None, + patch_torch_functions=None, + keep_batchnorm_fp32=None, + master_weights=None, + loss_scale=None, + cast_model_outputs=None, + num_losses=1, + verbosity=1, + min_loss_scale=None, + max_loss_scale=16777216.0 +) +``` + +参数: +- enabled(bool, optional, default=True): False 会使所有 AMP 调用成为空操作, 程序将会像没有使用 AMP 一样运行。 + +- opt_level(str, optional, default="O1" ): 纯精度或混合精度优化水平。可选值 “O0”, “O1”, “O2”, and “O3”, 详细解释见上方 Apex AMP 文档。 + +- num_losses(int, optional, default=1): 选择提前告知 AMP 您计划使用多少次损失/反向计算。 +当`amp.scale_loss`与 loss_id 参数一起使用时,使 AMP 在每次损失/反向计算时使用不同的损失比例,这可以提高稳定性。如果 num_losses 被设置为1,AMP 仍支持多次损失/反向计算,但对他们都使用同一个全局损失比例。 + +- verbosity(int, default=1): 设置为0抑制 AMP 相关输出。 + +- min_loss_scale(float, default=None): 为可通过动态损耗比例选择的损耗比例值设置下限。 +默认值“None”意味着不设置任何下限。如果不使用动态损耗比例,则忽略 min_loss_scale 。 + +- max_loss_scale(float, default=2.**24 ): 为可通过动态损耗比例选择的损耗比例值设置上限。如果不使用动态损耗比例,则 max_loss_scale 被忽略. + +目前,管理纯精度或混合精度训练的幕后属性有以下几种: +cast_model_type, patch_torch_functions, keep_batchnorm_fp32, master_weights, loss_scale. +一旦 opt_level 被确定,它们是可选的可覆盖属性 + +- cast_model_type: 将模型的参数和缓冲区强制转换为所需的类型。 +- patch_torch_functions: 补全所有的 Torch 函数和张量方法,以便在FP16中执行张量核心友好的操作,如 GEMMs 和卷积,以及在 FP32 中执行任何受益于 FP32 精度的操作。 +- keep_batchnorm_fp32: 为了提高精度并启用 cudnn batchnorm (这会提高性能),在 FP32 中保留 batchnorm 权重通常是有益的,即使模型的其余部分是 FP16。 +- master_weights: 保持 FP32 主权重以配合任何 FP16 模型权重。 FP32 主权重由优化器分级,以提高精度和捕捉小梯度。 +- loss_scale: 如果 loss_scale 是一个浮点数,则使用这个值作为静态(固定)的损失比例。如果 loss_scale 是字符串 "dynamic",则随着时间的推移自适应地调整损失比例。动态损失比例调整由 AMP 自动执行。 + + +### Naive AMP 配置 + +在 Naive AMP 模式中, 我们实现了混合精度训练,同时保持了与复杂张量和流水并行的兼容性。该 AMP 模式将所有操作转为 FP16 。下列代码块展示了该模式的`config.py`。 + +```python +from colossalai.amp import AMP_TYPE + +fp16 = dict( + mode=AMP_TYPE.NAIVE, + + # below are the default values + log_num_zeros_in_grad=False, + initial_scale=2 ** 32, + min_scale=1, + growth_factor=2, + backoff_factor=0.5, + growth_interval=1000, + hysteresis=2 +) +``` + +Naive AMP 的默认参数: +- log_num_zeros_in_grad(bool): 返回0值梯度的个数. +- initial_scale(int): gradient scaler 的初始值 +- growth_factor(int): loss scale 的增长率 +- backoff_factor(float): loss scale 的下降率 +- hysterisis(int): 动态 loss scaling 的延迟偏移 +- max_scale(int): loss scale 的最大允许值 +- verbose(bool): 如果被设为`True`,将打印调试信息 + +当使用`colossalai.initialize`时, 首先需要实例化一个模型、一个优化器和一个标准。将输出模型转换为内存消耗较小的 AMP 模型。如果您的输入模型已经太大,无法放置在 GPU 中,请使用`dtype=torch.float16`实例化你的模型。或者请尝试更小的模型,或尝试更多的并行化训练技术! + +## 实例 + +我们提供了一个 [运行实例](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/amp) +展现如何在 Colossal-AI 使用 AMP。在该例程中,我们使用 Torch AMP, 但提供的配置文件也适用于所有 AMP 模式. + +### 步骤 1. 创建配置文件 + +创建一个`config.py`文件并添加`fp16`配置. + +```python +# in config.py +from colossalai.amp import AMP_TYPE + +BATCH_SIZE = 128 +DROP_RATE = 0.1 +NUM_EPOCHS = 300 + +fp16 = dict( + mode=AMP_TYPE.TORCH, +) + +clip_grad_norm = 1.0 +``` + +### 步骤 2. 在 train_with_engine.py 导入相关库 + +创建`train_with_engine.py`并导入必要依赖. 请记得通过命令`pip install timm scipy`安装`scipy`和`timm`。 + +```python +import os +import colossalai +import torch +from pathlib import Path +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.utils import get_dataloader +from colossalai.trainer import Trainer, hooks +from colossalai.nn.lr_scheduler import LinearWarmupLR +from timm.models import vit_base_patch16_224 +from torchvision import datasets, transforms + +``` + +### 步骤 3. 初始化分布式环境 + +我们需要初始化分布式环境。为了快速演示,我们使用`launch_from_torch`。你可以参考 [Launch Colossal-AI](../basics/launch_colossalai.md) +使用其他初始化方法。 + +```python +# 初始化分布式设置 +parser = colossalai.get_default_parser() +args = parser.parse_args() + +# launch from torch +colossalai.launch_from_torch(config=args.config) + +``` + +### 步骤 4. 创建训练组件 + +构建你的模型、优化器、损失函数、学习率调整器和数据加载器。注意数据集的路径从环境变量`DATA`获得。你可以通过 `export DATA=/path/to/data` 或 `Path(os.environ['DATA'])` +在你的机器上设置路径。数据将会被自动下载到该路径。 + +```python +# build model + model = vit_base_patch16_224(drop_rate=0.1) + + # build dataloader + train_dataset = datasets.Caltech101( + root=Path(os.environ['DATA']), + download=True, + transform=transforms.Compose([ + transforms.Resize(256), + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + Gray2RGB(), + transforms.Normalize([0.5, 0.5, 0.5], + [0.5, 0.5, 0.5]) + ])) + + train_dataloader = get_dataloader(dataset=train_dataset, + shuffle=True, + batch_size=gpc.config.BATCH_SIZE, + num_workers=1, + pin_memory=True, + ) + + # build optimizer + optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=0.1) + + # build loss + criterion = torch.nn.CrossEntropyLoss() + + # lr_scheduelr + lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS) +``` + +### 步骤 5. 插入 AMP + +调用 `colossalai.initialize` 将所有训练组件转为为FP16模式. + +```python +engine, train_dataloader, _, _ = colossalai.initialize( + model, optimizer, criterion, train_dataloader, + ) +``` + +### 步骤 6. 使用 Engine 训练 + +使用Engine构建一个普通的训练循环 + +```python +engine.train() +for epoch in range(gpc.config.NUM_EPOCHS): + for img, label in enumerate(train_dataloader): + img = img.cuda() + label = label.cuda() + engine.zero_grad() + output = engine(img) + loss = engine.criterion(output, label) + engine.backward(loss) + engine.step() + lr_scheduler.step() +``` + +### 步骤 7. 启动训练脚本 + +使用下列命令启动训练脚本,你可以改变 `--nproc_per_node` 以使用不同数量的 GPU。 + +```python +python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py --config config/config_AMP_torch.py +``` diff --git a/docs/source/zh-Hans/features/nvme_offload.md b/docs/source/zh-Hans/features/nvme_offload.md new file mode 100644 index 0000000000000000000000000000000000000000..fd75ed1f5b3ecb10b51846c73f6afb2d5a1234c2 --- /dev/null +++ b/docs/source/zh-Hans/features/nvme_offload.md @@ -0,0 +1,251 @@ +# NVMe offload + +作者: Hongxin Liu + +**前置教程:** +- [基于Chunk内存管理的零冗余优化器 (ZeRO)](../features/zero_with_chunk.md) + +**相关论文** + +- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840) +- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857) +## 引言 + +如果模型具有`N`个参数,在使用 Adam 时,优化器状态具有`8N`个参数。对于十亿规模的模型,优化器状态至少需要 32 GB 内存。 GPU显存限制了我们可以训练的模型规模,这称为GPU显存墙。如果我们将优化器状态 offload 到磁盘,我们可以突破 GPU 内存墙。 + +我们实现了一个用户友好且高效的异步 Tensor I/O 库:[TensorNVMe](https://github.com/hpcaitech/TensorNVMe)。有了这个库,我们可以简单地实现 NVMe offload。 + +> 该库与各种磁盘(HDD、SATA SSD 和 NVMe SSD)兼容。由于 HDD 或 SATA SSD 的 I/O 带宽较低,建议仅在 NVMe 磁盘上使用此库。 + +在优化参数时,我们可以将优化过程分为三个阶段:读取、计算和 offload。我们以流水线的方式执行优化过程,这可以重叠计算和 I/O。 + +
+ +
优化过程
+
+ + +## 使用 + +首先,请确保您安装了 [TensorNVMe](https://github.com/hpcaitech/TensorNVMe): + +```shell +pip install packaging +pip install tensornvme +``` + +我们为 Adam ([CPUAdam](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.cpu_adam.html) 和 [HybridAdam](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.hybrid_adam.html)) 实现了优化器状态的 NVMe offload。 + + + +```python +from colossalai.nn.optimizer import CPUAdam, HybridAdam + +optimizer = HybridAdam(model.parameters(), lr=1e-3, nvme_offload_fraction=1.0, nvme_offload_dir='./') +``` + + + +`nvme_offload_fraction` 是要 offload 到 NVMe 的优化器状态的比例。 `nvme_offload_dir` 是保存 NVMe offload 文件的目录。如果 `nvme_offload_dir` 为 `None`,将使用随机临时目录。 + +它与 ColossalAI 中的所有并行方法兼容。 + + +> ⚠ 它只会卸载在 CPU 上的优化器状态。这意味着它只会影响 CPU 训练或者使用卸载的 Zero/Gemini。 + +## Exampls + +Let's start from two simple examples -- training GPT with different methods. These examples relies on `transformers`. +首先让我们从两个简单的例子开始 -- 用不同的方法训练 GPT。这些例子依赖`transformers`。 + +我们首先应该安装依赖: + +```shell +pip install psutil transformers +``` + +首先,我们导入必要的包和模块: + +```python +import os +import time +from typing import Dict, Optional +import psutil +import torch +import torch.nn as nn +from transformers.models.gpt2.configuration_gpt2 import GPT2Config +from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.zero import zero_model_wrapper, zero_optim_wrapper +from colossalai.utils.model.colo_init_context import ColoInitContext +``` + +然后我们定义一个损失函数: + +```python +class GPTLMLoss(nn.Module): + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1)) +``` + +我们定义一些工具函数,用来生成随机数据、计算模型参数量和获取当前进程内存占用: + +```python +def get_data(batch_size: int, seq_len: int, + vocab_size: int, device: Optional[str] = None) -> Dict[str, torch.Tensor]: + device = torch.cuda.current_device() if device is None else device + input_ids = torch.randint(vocab_size, (batch_size, seq_len), + device=device) + attn_mask = torch.ones_like(input_ids) + return dict(input_ids=input_ids, attention_mask=attn_mask) +def get_model_numel(model: nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) +def get_mem_usage() -> int: + proc = psutil.Process(os.getpid()) + return proc.memory_info().rss +``` + +我们首先尝试在 CPU 上训练 GPT 模型: + +```python +def train_cpu(nvme_offload_fraction: float = 0.0): + config = GPT2Config() + model = GPT2LMHeadModel(config) + criterion = GPTLMLoss() + optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction) + print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B') + start = time.time() + for step in range(3): + data = get_data(4, 128, config.vocab_size, device='cpu') + outputs = model(**data) + loss = criterion(outputs.logits, data['input_ids']) + loss.backward() + optimizer.step() + optimizer.zero_grad() + print(f'[{step}] loss: {loss.item():.3f}') + print(f'Time: {time.time() - start:.3f} s') + print(f'Mem usage: {get_mem_usage() / 1024**2:.3f} MB') +``` + +不使用 NVME 卸载: + +```python +train_cpu(0.0) +``` + +我们可能得到如下输出: + +``` +Model numel: 0.116 B +[0] loss: 10.953 +[1] loss: 10.974 +[2] loss: 10.965 +Time: 7.739 s +Mem usage: 5966.445 MB +``` + +然后使用(全量) NVME 卸载: + +```python +train_cpu(1.0) +``` + +我们可能得到: + +``` +Model numel: 0.116 B +[0] loss: 10.951 +[1] loss: 10.994 +[2] loss: 10.984 +Time: 8.527 s +Mem usage: 4968.016 MB +``` + +对于有1.16亿参数的 GPT2-S 来说,它的优化器状态大约需要占用 0.928 GB 内存。NVME 卸载节省了大约 998 MB 内存,符合我们的预期。 + +然后我们可以用 Gemini 来训练 GPT 模型。放置策略应该设置为`"auto"`、 `"cpu"` 或 `"const"`。 + +```python +def train_gemini_cpu(nvme_offload_fraction: float = 0.0): + colossalai.launch_from_torch({}) + config = GPT2Config() + with ColoInitContext(device=torch.cuda.current_device()): + model = GPT2LMHeadModel(config) + criterion = GPTLMLoss() + optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction) + print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B') + gemini_config = dict(strict_ddp_mode=True, device=torch.cuda.current_device(), + placement_policy='cpu', pin_memory=True, hidden_dim=config.n_embd) + model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config) + optimizer = zero_optim_wrapper(model, optimizer, initial_scale=2**5) + start = time.time() + for step in range(3): + data = get_data(4, 128, config.vocab_size) + outputs = model(**data) + loss = criterion(outputs.logits, data['input_ids']) + optimizer.backward(loss) + optimizer.step() + optimizer.zero_grad() + print(f'[{step}] loss: {loss.item():.3f}') + print(f'Time: {time.time() - start:.3f} s') + print(f'Mem usage: {get_mem_usage() / 1024**2:.3f} MB') +``` + +不使用 NVME 卸载: + +```python +train_gemini_cpu(0.0) +``` + +我们可能得到: + +``` +Model numel: 0.116 B +searching chunk configuration is completed in 0.27 s. +used number: 118.68 MB, wasted number: 0.75 MB +total wasted percentage is 0.63% +[0] loss: 10.953 +[1] loss: 10.938 +[2] loss: 10.969 +Time: 2.997 s +Mem usage: 5592.227 MB +``` + +然后使用(全量) NVME 卸载: + +```python +train_gemini_cpu(1.0) +``` + +我们可能得到: + +``` +Model numel: 0.116 B +searching chunk configuration is completed in 0.27 s. +used number: 118.68 MB, wasted number: 0.75 MB +total wasted percentage is 0.63% +[0] loss: 10.953 +[1] loss: 10.938 +[2] loss: 10.969 +Time: 3.691 s +Mem usage: 5298.344 MB +``` + +NVME 卸载节省了大约 294 MB 内存。注意使用 Gemini 的 `pin_memory` 功能可以加速训练,但是会增加内存占用。所以这个结果也是符合我们预期的。如果我们关闭 `pin_memory`,我们仍然可以观察到大约 900 MB 的内存占用下降。 + +## API 参考 + +{{ autodoc:colossalai.nn.optimizer.HybridAdam }} + +{{ autodoc:colossalai.nn.optimizer.CPUAdam }} + + + diff --git a/docs/source/zh-Hans/features/pipeline_parallel.md b/docs/source/zh-Hans/features/pipeline_parallel.md new file mode 100644 index 0000000000000000000000000000000000000000..98096b1d7f9378bf178c6da9a5febfdfae67efb3 --- /dev/null +++ b/docs/source/zh-Hans/features/pipeline_parallel.md @@ -0,0 +1,158 @@ +# 流水并行 + +作者: Guangyang Lu, Hongxin Liu, Yongbin Li + +**前置教程** +- [定义配置文件](../basics/define_your_config.md) +- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) +- [并行配置](../basics/configure_parallelization.md) + +**示例代码** +- [ColossalAI-Examples ResNet with pipeline](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/pipeline_parallel) + +**相关论文** +- [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883) +- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) +- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965) + +## 快速预览 + +在本教程中,你将学习如何使用流水并行。在 Colossal-AI 中, 我们使用 NVIDIA 推出的 1F1B 流水线。由于在本例中, 使用 ViT 和 ImageNet 太过庞大,因此我们使用 ResNet 和 CIFAR 为例. + +## 目录 + +在本教程中,我们将介绍: + +1. 介绍 1F1B 流水线; +2. 使用非交错和交错 schedule; +3. 使用流水线训练 ResNet。 + +## 认识 1F1B 流水线 + +首先,我们将向您介绍 GPipe,以便您更好地了解。 + +
+ +
图1: GPipe,来自论文 Megatron-LM
+
+ +正如你所看到的,对于 GPipe,只有当一个批次中所有 microbatches 的前向计算完成后,才会执行后向计算。 + +一般来说,1F1B(一个前向通道和一个后向通道)比 GPipe (在内存或内存和时间方面)更有效率。1F1B 流水线有两个 schedule ,非交错式和交错式,图示如下。 +
+ +
Figure2: 图片来自论文 Megatron-LM 。上面的部分显示了默认的非交错 schedule,底部显示的是交错的 schedule。
+
+ +### 非交错 Schedule + +非交错式 schedule 可分为三个阶段。第一阶段是热身阶段,处理器进行不同数量的前向计算。在接下来的阶段,处理器进行一次前向计算,然后是一次后向计算。处理器将在最后一个阶段完成后向计算。 + +这种模式比 GPipe 更节省内存。然而,它需要和 GPipe 一样的时间来完成一轮计算。 + +### 交错 Schedule + +这个 schedule 要求**microbatches的数量是流水线阶段的整数倍**。 + +在这个 schedule 中,每个设备可以对多个层的子集(称为模型块)进行计算,而不是一个连续层的集合。具体来看,之前设备1拥有层1-4,设备2拥有层5-8,以此类推;但现在设备1有层1,2,9,10,设备2有层3,4,11,12,以此类推。 +在该模式下,流水线上的每个设备都被分配到多个流水线阶段,每个流水线阶段的计算量较少。 + +这种模式既节省内存又节省时间。 + +## 使用schedule + +在 Colossal-AI 中, 我们提供非交错(`PipelineSchedule`) 和交错(`InterleavedPipelineSchedule`)schedule。 + +你只需要在配置文件中,设置 `NUM_MICRO_BATCHES` 并在你想使用交错schedule的时候,设置 `NUM_CHUNKS`。 如果你确定性地知道每个管道阶段的输出张量的形状,而且形状都是一样的,你可以设置 `tensor_shape` 以进一步减少通信。否则,你可以忽略 `tensor_shape` , 形状将在管道阶段之间自动交换。 我们将会根据用户提供的配置文件,生成一个合适schedule来支持用户的流水并行训练。 + +## 使用流水线训练 ResNet + +我们首先用Colossal PipelinableContext方式建立 `ResNet` 模型: +```python +import os +from typing import Callable, List, Optional, Type, Union +import torch +import torch.nn as nn +import colossalai +import colossalai.nn as col_nn + +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.trainer import Trainer, hooks +from colossalai.utils import MultiTimer, get_dataloader +from colossalai.context import ParallelMode +from colossalai.pipeline.pipelinable import PipelinableContext + +from titans.dataloader.cifar10 import build_cifar +from torchvision.models import resnet50 +from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 + +# Define some config +BATCH_SIZE = 64 +NUM_EPOCHS = 2 +NUM_CHUNKS = 1 +CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2)) + +# Train +disable_existing_loggers() +parser = colossalai.get_default_parser() +args = parser.parse_args() +colossalai.launch_from_torch(backend=args.backend, config=CONFIG) +logger = get_dist_logger() +pipelinable = PipelinableContext() + +# build model +with pipelinable: + model = resnet50() +``` + +给定切分顺序,module直接给出name,部分函数需要手动添加。 +```python +exec_seq = [ + 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', + (lambda x: torch.flatten(x, 1), "behind"), 'fc' +] +pipelinable.to_layer_list(exec_seq) +``` + +将模型切分成流水线阶段。 +```python +model = pipelinable.partition(NUM_CHUNKS, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) +``` + +我们使用`Trainer`训练`ResNet`: +```python +# build criterion +criterion = nn.CrossEntropyLoss() + +# optimizer +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + +# build dataloader +root = os.environ.get('DATA', './data') +train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, padding=4, crop=32, resize=32) + +lr_scheduler = col_nn.lr_scheduler.LinearWarmupLR(optimizer, NUM_EPOCHS, warmup_steps=1) +engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion, + train_dataloader, test_dataloader, + lr_scheduler) +timer = MultiTimer() + +trainer = Trainer(engine=engine, timer=timer, logger=logger) + +hook_list = [ + hooks.LossHook(), + hooks.AccuracyHook(col_nn.metric.Accuracy()), + hooks.LogMetricByEpochHook(logger), + hooks.LRSchedulerHook(lr_scheduler, by_epoch=True) +] + +trainer.fit(train_dataloader=train_dataloader, + epochs=NUM_EPOCHS, + test_dataloader=test_dataloader, + test_interval=1, + hooks=hook_list, + display_progress=True) +``` + +我们使用 `2` 个流水段,并且 batch 将被切分为 `4` 个 micro batches。 diff --git a/docs/source/zh-Hans/features/zero_with_chunk.md b/docs/source/zh-Hans/features/zero_with_chunk.md new file mode 100644 index 0000000000000000000000000000000000000000..72403bf610a4f9523f88a5d0791417e0daed3bd8 --- /dev/null +++ b/docs/source/zh-Hans/features/zero_with_chunk.md @@ -0,0 +1,265 @@ +# 基于Chunk内存管理的零冗余优化器 (ZeRO) + +作者: [Hongxiu Liu](https://github.com/ver217), [Jiarui Fang](https://github.com/feifeibear), [Zijian Ye](https://github.com/ZijianYY) + +**前置教程:** + +- [定义配置文件](../basics/define_your_config.md) + +**示例代码** + +- [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt) + +**相关论文** + +- [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054) +- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840) +- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857) +- [DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters](https://dl.acm.org/doi/10.1145/3394486.3406703) +- [PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management](https://arxiv.org/abs/2108.05818) + + +## 引言 + +零冗余优化器 (ZeRO) 通过对三个模型状态(优化器状态、梯度和参数)进行划分而不是复制他们,消除了数据并行进程中的内存冗余。该方法与传统的数据并行相比,内存效率得到了极大的提高,而计算粒度和通信效率得到了保留。 + +1. **分片优化器状态**: 优化器状态 (如 [Adam optimizer](https://arxiv.org/abs/1412.6980), 32位的权重, +以及一二阶动量估计) 被划分到各个进程中, 因此每个进程只更新其分区。 + + +2. **分片梯度**: 在梯度在数据并行进程组内进行 reduction 后, 梯度张量也被划分,这样每个进程只存储与其划分的优化器状态对应的梯度。 注意, Colossal-AI 将梯度转换为 FP32 格式以参与更新参数。 + +3. **分片参数**: 16位的模型参数被划分到一个数据并行组的进程中。 + +4. **[Gemini](../advanced_tutorials/meet_gemini.md)**: 对于参数、梯度、优化器状态的动态异构内存空间管理器。 + +此外,我们还将介绍基于Chunk内存管理的零冗余优化器。 + +在使用零冗余优化器 (ZeRO)时,我们通过切分参数的方式对模型进行分布式存储,这种方法的优点是每个节点的内存负载是完全均衡的。但是这种方式有很多缺点。首先,通信时需要申请一块临时内存用来通信,通信完毕释放,这回导致存在内存碎片化的问题。其次,以Tensor为粒度进行通信,会导致网络带宽无法充分利用。通常来说传输的消息长度越长带宽利用率越高。 + +利用ColossalAI v0.1.8引入了Chunk机制,我们可以提升ZeRO的性能。我们将运算顺序上连续的一组参数存入一个Chunk中(Chunk即一段连续的内存空间),每个Chunk的大小相同。Chunk方式组织内存可以保证PCI-e和GPU-GPU之间网络带宽的高效利用,减小了通信次数,同时避免潜在的内存碎片。 + +在v0.1.8之前,ZeRO在进行参数聚合时通信成本较高,如果一个参数在连续的几次计算中被使用多次,即会发生多次通信,效率较低。这种情况在使用Checkpoint时非常常见,参数在计算backward时会重计算一遍forward。这种情况下,ZeRO的效率便不高。 + +以GPT为例,其Checkpoint会应用在每一个GPT Block上,每一个GPT Block包含一个Self-Attention层和MLP层。在计算Backward时,会依次计算Self-Attention层、MLP层的forward,然后依次计算MLP层、Self-Attention层的backward。如使用Chunk机制,我们将Self-Attention层和MLP层放在同一个Chunk中,在每个GPT Block的backward的中便无需再通信。 + +除此之外,由于小Tensor的通信、内存移动没法完全利用NVLINK、PCIE带宽,而且每次通信、内存移动都有kernel launch的开销。使用了Chunk之后可以把多次小Tensor的通信、内存移动变为一次大Tensor的通信、内存移动,既提高了带宽利用,也减小了kernel launch的开销。 + +我们提供了轻量级的Chunk搜索机制,帮助用户自动找到内存碎片最小的Chunk尺寸。 + +## 使用 + +### GeminiDDP + +我们将运用`GeminiDDP`的方式来使用基于Chunk内存管理的ZeRO。这是我们新包装的torch.Module ,它使用 ZeRO-DP 和 Gemini,其中ZeRO 用于并行,Gemini 用于内存管理。 + +同样需要确保你的模型是在 `ColoInitContext` 的上下文中初始化的。 + +```python +with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg): + model = gpt2_medium(checkpoint=True) +``` + +定义模型参数如下: + +```python +chunk_manager = init_chunk_manager(model=module, + init_device=device, + hidden_dim=hidden_dim, + search_range_mb=search_range_mb, + min_chunk_size_mb=min_chunk_size_mb) +gemini_manager = GeminiManager(placement_policy, chunk_manager) +model = ZeroDDP(model, gemini_manager) +``` + +`hidden dim`是DNN的隐藏维度。用户可以提供这个参数来加快搜索速度。如果用户在训练前不知道这个参数也可以。 我们将使用默认值 1024。`min_chunk_size_mb`是以兆字节为单位的最小块大小。如果参数的总大小仍然小于最小块大小,则所有参数将被压缩为一个小块。 + +初始化优化器。 +```python +optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5) +``` + +训练 +```python +optimizer.zero_grad() +outputs = model(input_ids, attn_mask) +loss = criterion(outputs, input_ids) +optimizer.backward(loss) +optimizer.step() +``` +> ⚠️ 注意:请不要使用`loss.backward()`,规范写法是`optimizer.backward(loss)`。 + +### 训练GPT + +在此例程中, 我们使用 `Hugging Face Transformers`,并以 `GPT2 Medium` 为例。你必须在允许该例程前安装 `transformers`。 + +为了简单起见,我们在这里只使用随机生成的数据。 + +首先我们只需要引入`Huggingface transformers` 的 `GPT2LMHeadModel`来定义我们的模型,不需要用户进行模型的定义与修改,方便用户使用。 + +```python +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + +def gpt2_medium(checkpoint=False): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) +``` + +定义损失函数: + +```python +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) +``` + +定义张量并行和参数分片策略: + +```python +def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): + for mn, module in model.named_modules(): + for pn, param in module.named_parameters(recurse=False): + if hasattr(param, 'visited'): + continue + param.set_dist_spec(ReplicaSpec()) + if 'mlp.c_fc' in mn: + if 'weight' in pn or 'bias' in pn: + split_param_col_tp1d(param, pg) + param.compute_spec.set_output_replicate(False) + else: + param.set_dist_spec(ReplicaSpec()) + elif 'mlp.c_proj' in mn: + if 'weight' in pn: + split_param_row_tp1d(param, pg) + else: + param.set_dist_spec(ReplicaSpec()) + elif 'wte' in mn or 'wpe' in mn: + split_param_col_tp1d(param, pg) + elif 'c_attn' in mn or 'c_proj' in mn: + split_param_col_tp1d(param, pg) + else: + param.set_dist_spec(ReplicaSpec()) + + param.visited = True +def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(-1, param, pg) +``` + +定义一个使用 Gemini + ZeRO DDP 的模型: + +```python +def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): + cai_version = colossalai.__version__ + if version.parse(cai_version) > version.parse("0.1.10"): + from colossalai.nn.parallel import GeminiDDP + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=placememt_policy, + pin_memory=True, + search_range_mb=32) + elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): + from colossalai.gemini import ChunkManager, GeminiManager + chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) + gemini_manager = GeminiManager(placememt_policy, chunk_manager) + chunk_manager = ChunkManager(chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(placememt_policy)) + model = ZeroDDP(model, gemini_manager) + else: + raise NotImplemented(f"CAI version {cai_version} is not supported") + return model +``` + +由于我们在这个例子中对GPT进行预训练,因此只使用了一个简单的语言模型损失函数。 + +写一个获得随机输入的函数: + +```python +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask +``` + +最后,我们可以定义我们的训练循环: + +```python +def main(): + args = parse_args() + BATCH_SIZE = 8 + SEQ_LEN = 1024 + VOCAB_SIZE = 50257 + NUM_STEPS = 10 + colossalai.launch_from_torch(config={}) + + # build criterion + criterion = GPTLMLoss() + + torch.manual_seed(123) + default_pg = ProcessGroup(tp_degree=args.tp_degree) + default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None + # build GPT model + with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg): + model = gpt2_medium(checkpoint=True) + pg = default_pg + # Tensor Parallelism (TP) + tensor_parallelize(model, pg) + # Gemini + ZeRO DP, Note it must be used after TP + model = gemini_zero_dpp(model, pg, args.placement) + # build optimizer + optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5) + numel = sum([p.numel() for p in model.parameters()]) + get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN) + torch.cuda.synchronize() + model.train() + for n in range(NUM_STEPS): + # we just use randomly generated data here + input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE) + optimizer.zero_grad() + outputs = model(input_ids, attn_mask) + loss = criterion(outputs, input_ids) + optimizer.backward(loss) + optimizer.step() + + torch.cuda.synchronize() +``` +> ⚠️ 注意:如果你使用Gemini模块的话,请不要使用我们之前提到过的[梯度累加](../features/gradient_accumulation.md)。 +完整的例子代码可以在 [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). 获得。 diff --git a/docs/source/zh-Hans/get_started/installation.md b/docs/source/zh-Hans/get_started/installation.md new file mode 100755 index 0000000000000000000000000000000000000000..72f85393814fb97eb5679d26c82d74f102e6cd49 --- /dev/null +++ b/docs/source/zh-Hans/get_started/installation.md @@ -0,0 +1,50 @@ +# 安装 + +环境要求: + +- PyTorch >= 1.11 (PyTorch 2.x 正在适配中) +- Python >= 3.7 +- CUDA >= 11.0 +- [NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus) >= 7.0 (V100/RTX20 and higher) +- Linux OS + +如果你遇到安装问题,可以向本项目 [反馈](https://github.com/hpcaitech/ColossalAI/issues/new/choose)。 + +## 从PyPI上安装 + +你可以PyPI上使用以下命令直接安装Colossal-AI。 + +```shell +pip install colossalai +``` + +**注:现在只支持Linux。** + +如果你想同时安装PyTorch扩展的话,可以添加`CUDA_EXT=1`。如果不添加的话,PyTorch扩展会在运行时自动安装。 + +```shell +CUDA_EXT=1 pip install colossalai +``` + +## 从源安装 + +> 此文档将与版本库的主分支保持一致。如果您遇到任何问题,欢迎给我们提 issue :) + +```shell +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI + +# install dependency +pip install -r requirements/requirements.txt + +# install colossalai +pip install . +``` + +如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装): + +```shell +NO_CUDA_EXT=1 pip install . +``` + + diff --git a/docs/source/zh-Hans/get_started/reading_roadmap.md b/docs/source/zh-Hans/get_started/reading_roadmap.md new file mode 100755 index 0000000000000000000000000000000000000000..8813386dc9c17b29bb092858c1e0909e502b1ba9 --- /dev/null +++ b/docs/source/zh-Hans/get_started/reading_roadmap.md @@ -0,0 +1,10 @@ +# 阅读指引 + +Colossal-AI为您提供了一系列的并行训练组件。我们的目标是支持您开发分布式深度学习模型,就像您编写单GPU深度学习模型一样简单。ColossalAI提供了易于使用的API来帮助您启动您的训练过程。为了更好地了解ColossalAI的工作原理,我们建议您按照以下顺序阅读本文档。 + +- 如果您不熟悉分布式系统,或者没有使用过Colossal-AI,您可以先浏览`概念`部分,了解我们要实现的目标同时掌握一些关于分布式训练的背景知识。 +- 接下来,您可以按照`基础教程`进行学习。该节将介绍关于如何使用Colossal-AI的细节。 +- 这时候,您就可以小试牛刀了!`功能` 部分将帮助您尝试如何使用Colossal-AI为您的模型训练进行加速。我们将为每个教程提供一个代码库。这些教程将涵盖Colossal-AI的基本用法,以实现简单的功能,如数据并行和混合精度训练。 +- 最后,如果您希望应用更高超的技术,比如,如何在GPT-3上运行混合并行,快来`高级教程`部分学习如何搭建您自己的模型吧! + +**我们始终欢迎社区的建议和讨论,如果您遇到任何问题,我们将非常愿意帮助您。您可以在GitHub 提 [issue](https://github.com/hpcaitech/ColossalAI/issues) ,或在[论坛](https://github.com/hpcaitech/ColossalAI/discussions)上创建一个讨论主题。** diff --git a/docs/source/zh-Hans/get_started/run_demo.md b/docs/source/zh-Hans/get_started/run_demo.md new file mode 100755 index 0000000000000000000000000000000000000000..edfc246c22d5672f9d9125042bb4ff4e82b90eba --- /dev/null +++ b/docs/source/zh-Hans/get_started/run_demo.md @@ -0,0 +1,28 @@ +# 快速演示 + +Colossal-AI 是一个集成的大规模深度学习系统,具有高效的并行化技术。该系统可以通过应用并行化技术在具有多个 GPU 的分布式系统上加速模型训练。该系统也可以在只有一个 GPU 的系统上运行。以下是展示如何使用 Colossal-AI 的 Quick demos。 + +## 单 GPU + +Colossal-AI 可以用在只有一个 GPU 的系统上训练深度学习模型,并达到 baseline 的性能。 我们提供了一个 [在CIFAR10数据集上训练ResNet](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/resnet) 的例子,该例子只需要一个 GPU。 +您可以在 [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples) 中获取该例子。详细说明可以在其 `README.md` 中获取。 + +## 多 GPU + +Colossal-AI 可用于在具有多个 GPU 的分布式系统上训练深度学习模型,并通过应用高效的并行化技术大幅加速训练过程。我们提供了多种并行化技术供您尝试。 + +#### 1. 数据并行 + +您可以使用与上述单 GPU 演示相同的 [ResNet例子](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/resnet)。 通过设置 `--nproc_per_node` 为您机器上的 GPU 数量,您就能把数据并行应用在您的例子上了。 + +#### 2. 混合并行 + +混合并行包括数据、张量和流水线并行。在 Colossal-AI 中,我们支持不同类型的张量并行(即 1D、2D、2.5D 和 3D)。您可以通过简单地改变 `config.py` 中的配置在不同的张量并行之间切换。您可以参考 [GPT example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt), 更多细节能在它的 `README.md` 中被找到。 + +#### 3. MoE并行 + +我们提供了一个 [WideNet例子](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet) 来验证 MoE 的并行性。 WideNet 使用 Mixture of Experts(MoE)来实现更好的性能。更多的细节可以在我们的教程中获取:[教会您如何把Mixture of Experts整合到模型中](../advanced_tutorials/integrate_mixture_of_experts_into_your_model.md)。 + +#### 4. 序列并行 + +序列并行是为了解决NLP任务中的内存效率和序列长度限制问题。 我们在 [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples) 中提供了一个 [BERT例子](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/bert/sequene_parallel)。您可以按照 `README.md` 来执行代码。 diff --git a/docs/source/zh-Hans/sidebar_category_translation.json b/docs/source/zh-Hans/sidebar_category_translation.json new file mode 100644 index 0000000000000000000000000000000000000000..af4c460c448f92845ed8e38366b88db56b75acb7 --- /dev/null +++ b/docs/source/zh-Hans/sidebar_category_translation.json @@ -0,0 +1,26 @@ +{ + "sidebar.tutorialSidebar.category.Get started": { + "message": "快速开始", + "description": "The label for category Get started in sidebar docs" + }, + "sidebar.tutorialSidebar.category.Concepts": { + "message": "概念", + "description": "The label for category Concepts in sidebar docs" + }, + "sidebar.tutorialSidebar.category.Basics": { + "message": "基础", + "description": "The label for category Basics in sidebar docs" + }, + "sidebar.dotutorialSidebarcs.category.Features": { + "message": "功能", + "description": "The label for category Features in sidebar docs" + }, + "sidebar.dtutorialSidebarocs.category.Tensor Parallel": { + "message": "张量并行", + "description": "The label for category Tensor Parallel in sidebar docs" + }, + "sidebar.tutorialSidebar.category.Advanced Tutorials": { + "message": "高级教程", + "description": "The label for category Advanced Tutorials in sidebar docs" + } +} diff --git a/docs/versions.json b/docs/versions.json new file mode 100644 index 0000000000000000000000000000000000000000..dde32982b7982c2c3c8689f2fbff57556c092545 --- /dev/null +++ b/docs/versions.json @@ -0,0 +1,3 @@ +[ + "current" +] diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..142a735c68192ed819b215e36caf2a64755eeb92 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,64 @@ +# Colossal-AI Examples + +## Table of Contents + +- [Colossal-AI Examples](#colossal-ai-examples) + - [Table of Contents](#table-of-contents) + - [Overview](#overview) + - [Folder Structure](#folder-structure) + - [Integrate Your Example With Testing](#integrate-your-example-with-testing) + +## Overview + +This folder provides several examples accelerated by Colossal-AI. +Folders such as `images` and `language` include a wide range of deep learning tasks and applications. +The `community` folder aim to create a collaborative platform for developers to contribute exotic features built on top of Colossal-AI. +The `tutorial` folder is for everyone to quickly try out the different features in Colossal-AI. + +You can find applications such as Chatbot, AIGC and Biomedicine in the [Applications](https://github.com/hpcaitech/ColossalAI/tree/main/applications) directory. + +## Folder Structure + +```text +└─ examples + └─ images + └─ vit + └─ test_ci.sh + └─ train.py + └─ README.md + └─ ... + └─ ... +``` +## Invitation to open-source contribution +Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models! + +You may contact us or participate in the following ways: +1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks! +2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). +3. Join the Colossal-AI community on +[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. +4. Send your official proposal to email contact@hpcaitech.com + +Thanks so much to all of our amazing contributors! + +## Integrate Your Example With Testing + +Regular checks are important to ensure that all examples run without apparent bugs and stay compatible with the latest API. +Colossal-AI runs workflows to check for examples on a on-pull-request and weekly basis. +When a new example is added or changed, the workflow will run the example to test whether it can run. +Moreover, Colossal-AI will run testing for examples every week. + +Therefore, it is essential for the example contributors to know how to integrate your example with the testing workflow. Simply, you can follow the steps below. + +1. Create a script called `test_ci.sh` in your example folder +2. Configure your testing parameters such as number steps, batch size in `test_ci.sh`, e.t.c. Keep these parameters small such that each example only takes several minutes. +3. Export your dataset path with the prefix `/data` and make sure you have a copy of the dataset in the `/data/scratch/examples-data` directory on the CI machine. Community contributors can contact us via slack to request for downloading the dataset on the CI machine. +4. Implement the logic such as dependency setup and example execution + +## Community Dependency +We are happy to introduce the following nice community dependency repos that are powered by Colossal-AI: +- [lightning-ColossalAI](https://github.com/Lightning-AI/lightning) +- [HCP-Diffusion](https://github.com/7eu7d7/HCP-Diffusion) +- [KoChatGPT](https://github.com/airobotlab/KoChatGPT) +- [minichatgpt](https://github.com/juncongmoo/minichatgpt) diff --git a/examples/community/README.md b/examples/community/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fb2ca37ed98847a346bbe8b40cdb35b08a84d2fc --- /dev/null +++ b/examples/community/README.md @@ -0,0 +1,28 @@ +## Community Examples + +Community-driven Examples is an initiative that allows users to share their own examples to the Colossal-AI community, fostering a sense of community and making it easy for others to access and benefit from shared work. The primary goal with community-driven examples is to have a community-maintained collection of diverse and exotic functionalities built on top of the Colossal-AI package. + +If a community example doesn't work as expected, you can [open an issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) and @ the author to report it. + + +| Example | Description | Code Example | Colab |Author | +|:------------------|:---------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------|:-----------------------------------------|-----------------------------------------------------:| +| RoBERTa | Adding RoBERTa for SFT and Prompts model training | [RoBERTa](./roberta) | - | [YY Lin](https://github.com/yynil) (Moore Threads) | +| TransformerEngine FP8 | Adding TransformerEngine with FP8 training | [TransformerEngine FP8](./fp8) | - | [Kirthi Shankar Sivamani](https://github.com/ksivaman) (NVIDIA) | +|...|...|...|...|...| + +## Looking for Examples +* [Swin-Transformer](https://github.com/microsoft/Swin-Transformer) +* [T-5](https://github.com/google-research/text-to-text-transfer-transformer) +* [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything) +* [ControlNet](https://github.com/lllyasviel/ControlNet) +* [Consistency Models](https://github.com/openai/consistency_models) +* [MAE](https://github.com/facebookresearch/mae) +* [CLIP](https://github.com/openai/CLIP) + +Welcome to [open an issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) to share your insights and needs. + +## How to get involved +To join our community-driven initiative, please visit the [Colossal-AI examples](https://github.com/hpcaitech/ColossalAI/tree/main/examples), review the provided information, and explore the codebase. + +To contribute, create a new issue outlining your proposed feature or enhancement, and our team will review and provide feedback. If you are confident enough you can also submit a PR directly. We look forward to collaborating with you on this exciting project! diff --git a/examples/community/fp8/mnist/README.md b/examples/community/fp8/mnist/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e1128c1054b7bcfbaf0cfaafd83e32f1d803f944 --- /dev/null +++ b/examples/community/fp8/mnist/README.md @@ -0,0 +1,13 @@ +# Basic MNIST Example with optional FP8 of TransformerEngine + +[TransformerEngine](https://github.com/NVIDIA/TransformerEngine) is a library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper GPUs, to provide better performance with lower memory utilization in both training and inference. + +Thanks for the contribution to this tutorial from NVIDIA. + +```bash +python main.py +python main.py --use-te # Linear layers from TransformerEngine +python main.py --use-fp8 # FP8 + TransformerEngine for Linear layers +``` + +> We are working to integrate it with Colossal-AI and will finish it soon. diff --git a/examples/community/fp8/mnist/main.py b/examples/community/fp8/mnist/main.py new file mode 100644 index 0000000000000000000000000000000000000000..a534663d380f4b4e00cd0fe80de1d3525bac3e69 --- /dev/null +++ b/examples/community/fp8/mnist/main.py @@ -0,0 +1,225 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import argparse + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.optim.lr_scheduler import StepLR +from torchvision import datasets, transforms + +try: + from transformer_engine import pytorch as te + HAVE_TE = True +except (ImportError, ModuleNotFoundError): + HAVE_TE = False + + +class Net(nn.Module): + + def __init__(self, use_te=False): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + if use_te: + self.fc1 = te.Linear(9216, 128) + self.fc2 = te.Linear(128, 16) + else: + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 16) + self.fc3 = nn.Linear(16, 10) + + def forward(self, x): + """FWD""" + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + x = self.fc3(x) + output = F.log_softmax(x, dim=1) + return output + + +def train(args, model, device, train_loader, optimizer, epoch, use_fp8): + """Training function.""" + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + with te.fp8_autocast(enabled=use_fp8): + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % args.log_interval == 0: + print(f"Train Epoch: {epoch} " + f"[{batch_idx * len(data)}/{len(train_loader.dataset)} " + f"({100. * batch_idx / len(train_loader):.0f}%)]\t" + f"Loss: {loss.item():.6f}") + if args.dry_run: + break + + +def calibrate(model, device, test_loader): + """Calibration function.""" + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + with te.fp8_autocast(enabled=False, calibrating=True): + output = model(data) + + +def test(model, device, test_loader, use_fp8): + """Testing function.""" + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + with te.fp8_autocast(enabled=use_fp8): + output = model(data) + test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + + print(f"\nTest set: Average loss: {test_loss:.4f}, " + f"Accuracy: {correct}/{len(test_loader.dataset)} " + f"({100. * correct / len(test_loader.dataset):.0f}%)\n") + + +def main(): + # Training settings + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=14, + metavar="N", + help="number of epochs to train (default: 14)", + ) + parser.add_argument( + "--lr", + type=float, + default=1.0, + metavar="LR", + help="learning rate (default: 1.0)", + ) + parser.add_argument( + "--gamma", + type=float, + default=0.7, + metavar="M", + help="Learning rate step gamma (default: 0.7)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + default=False, + help="quickly check a single pass", + ) + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument( + "--save-model", + action="store_true", + default=False, + help="For Saving the current Model", + ) + parser.add_argument("--use-fp8", + action="store_true", + default=False, + help="Use FP8 for inference and training without recalibration") + parser.add_argument("--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only") + parser.add_argument("--use-te", action="store_true", default=False, help="Use Transformer Engine") + args = parser.parse_args() + use_cuda = torch.cuda.is_available() + + if args.use_te or args.use_fp8 or args.use_fp8_infer: + assert HAVE_TE, "TransformerEngine not installed." + + if args.use_fp8 or args.use_fp8_infer: + args.use_te = True + + if args.use_te: + assert use_cuda, "CUDA needed for FP8 execution." + + if args.use_fp8_infer: + assert not args.use_fp8, "fp8-infer path currently only supports calibration from a bfloat checkpoint" + + torch.manual_seed(args.seed) + + device = torch.device("cuda" if use_cuda else "cpu") + + train_kwargs = {"batch_size": args.batch_size} + test_kwargs = {"batch_size": args.test_batch_size} + if use_cuda: + cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) + dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform) + dataset2 = datasets.MNIST("../data", train=False, transform=transform) + train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) + test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) + + model = Net(use_te=args.use_te).to(device) + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + for epoch in range(1, args.epochs + 1): + train(args, model, device, train_loader, optimizer, epoch, args.use_fp8) + test(model, device, test_loader, args.use_fp8) + scheduler.step() + + if args.use_fp8_infer: + calibrate(model, device, test_loader) + + if args.save_model or args.use_fp8_infer: + torch.save(model.state_dict(), "mnist_cnn.pt") + print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8_infer)) + weights = torch.load("mnist_cnn.pt") + model.load_state_dict(weights) + test(model, device, test_loader, args.use_fp8_infer) + + +if __name__ == "__main__": + main() diff --git a/examples/community/roberta/README.md b/examples/community/roberta/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8aefa327a4b4bb7496f722b2713b986a5382b848 --- /dev/null +++ b/examples/community/roberta/README.md @@ -0,0 +1,50 @@ +# Introduction +This example introduce how to pretrain roberta from scratch, including preprocessing, pretraining, finetune. The example can help you quickly train a high-quality roberta. + +## 0. Prerequisite +- Install Colossal-AI +- Editing the port from `/etc/ssh/sshd_config` and `/etc/ssh/ssh_config`, every host expose the same ssh port of server and client. If you are a root user, you also set the **PermitRootLogin** from `/etc/ssh/sshd_config` to "yes" +- Ensure that each host can log in to each other without password. If you have n hosts, need to execute n2 times + +``` +ssh-keygen +ssh-copy-id -i ~/.ssh/id_rsa.pub ip_destination +``` + +- In all hosts, edit /etc/hosts to record all hosts' name and ip.The example is shown below. + +```bash +192.168.2.1 GPU001 +192.168.2.2 GPU002 +192.168.2.3 GPU003 +192.168.2.4 GPU004 +192.168.2.5 GPU005 +192.168.2.6 GPU006 +192.168.2.7 GPU007 +... +``` + +- restart ssh +``` +service ssh restart +``` + +## 1. Corpus Preprocessing +```bash +cd preprocessing +``` +following the `README.md`, preprocess original corpus to h5py plus numpy + +## 2. Pretrain + +```bash +cd pretraining +``` +following the `README.md`, load the h5py generated by preprocess of step 1 to pretrain the model + +## 3. Finetune + +The checkpoint produced by this repo can replace `pytorch_model.bin` from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) directly. Then use transfomers from Hugging Face to finetune downstream application. + +## Contributors +The example is contributed by AI team from [Moore Threads](https://www.mthreads.com/). If you find any problems for pretraining, please file an issue or send an email to yehua.zhang@mthreads.com. At last, welcome any form of contribution! diff --git a/examples/community/roberta/preprocessing/Makefile b/examples/community/roberta/preprocessing/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..82ee4e1c5b31bf53cc859c970268eb5070aa107f --- /dev/null +++ b/examples/community/roberta/preprocessing/Makefile @@ -0,0 +1,9 @@ +CXXFLAGS += -O3 -Wall -shared -std=c++14 -fPIC -fdiagnostics-color +CPPFLAGS += $(shell python3 -m pybind11 --includes) +LIBNAME = mask +LIBEXT = $(shell python3-config --extension-suffix) + +default: $(LIBNAME)$(LIBEXT) + +%$(LIBEXT): %.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ diff --git a/examples/community/roberta/preprocessing/README.md b/examples/community/roberta/preprocessing/README.md new file mode 100644 index 0000000000000000000000000000000000000000..17cc2f4dc22c36226560a1ad11c0f3f0493c84f1 --- /dev/null +++ b/examples/community/roberta/preprocessing/README.md @@ -0,0 +1,105 @@ +# Data PreProcessing for chinese Whole Word Masked + + + +## Catalogue: +* 1. Introduction +* 2. Quick Start Guide: + * 2.1. Split Sentence + * 2.2.Tokenizer & Whole Word Masked + + + + +## 1. Introduction: [Back to Top] +This folder is used to preprocess chinese corpus with Whole Word Masked. You can obtain corpus from [WuDao](https://resource.wudaoai.cn/home?ind&name=WuDaoCorpora%202.0&id=1394901288847716352). Moreover, data preprocessing is flexible, and you can modify the code based on your needs, hardware or parallel framework(Open MPI, Spark, Dask). + + + +## 2. Quick Start Guide: [Back to Top] + + + +### 2.1. Split Sentence & Split data into multiple shard: +Firstly, each file has multiple documents, and each document contains multiple sentences. Split sentence through punctuation, such as `。!`. **Secondly, split data into multiple shard based on server hardware (cpu, cpu memory, hard disk) and corpus size.** Each shard contains a part of corpus, and the model needs to train all the shards as one epoch. +In this example, split 200G Corpus into 100 shard, and each shard is about 2G. The size of the shard is memory-dependent, taking into account the number of servers, the memory used by the tokenizer, and the memory used by the multi-process training to read the shard (n data parallel requires n\*shard_size memory). **To sum up, data preprocessing and model pretraining requires fighting with hardware, not just GPU.** + +```python +python sentence_split.py --input_path /orginal_corpus --output_path /shard --shard 100 +# This step takes a short time +``` +* `--input_path`: all original corpus, e.g., /orginal_corpus/0.json /orginal_corpus/1.json ... +* `--output_path`: all shard with split sentences, e.g., /shard/0.txt, /shard/1.txt ... +* `--shard`: Number of shard, e.g., 10, 50, or 100 + +Input json: + +``` +[ + { + "id": 0, + "title": "打篮球", + "content": "我今天去打篮球。不回来吃饭。" + } + { + "id": 1, + "title": "旅游", + "content": "我后天去旅游。下周请假。" + } +] +``` + +Output txt: + +``` +我今天去打篮球。 +不回来吃饭。 +]] +我后天去旅游。 +下周请假。 +``` + + + +### 2.2. Tokenizer & Whole Word Masked: + +```python +python tokenize_mask.py --input_path /shard --output_path /h5 --tokenizer_path /roberta --backend python +# This step is time consuming and is mainly spent on mask +``` + +**[optional but recommended]**: the C++ backend with `pybind11` can provide faster speed + +```shell +make +``` + +* `--input_path`: location of all shard with split sentences, e.g., /shard/0.txt, /shard/1.txt ... +* `--output_path`: location of all h5 with token_id, input_mask, segment_ids and masked_lm_positions, e.g., /h5/0.h5, /h5/1.h5 ... +* `--tokenizer_path`: tokenizer path contains huggingface tokenizer.json. Download config.json, special_tokens_map.json, vocab.txt and tokenzier.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) +* `--backend`: python or c++, **specifies c++ can obtain faster preprocess speed** +* `--dupe_factor`: specifies how many times the preprocessor repeats to create the input from the same article/document +* `--worker`: number of process + +Input txt: + +``` +我今天去打篮球。 +不回来吃饭。 +]] +我后天去旅游。 +下周请假。 +``` + +Output h5+numpy: + +``` +'input_ids': [[id0,id1,id2,id3,id4,id5,id6,0,0..], + ...] +'input_mask': [[1,1,1,1,1,1,0,0..], + ...] +'segment_ids': [[0,0,0,0,0,...], + ...] +'masked_lm_positions': [[label1,-1,-1,label2,-1...], + ...] +``` diff --git a/examples/community/roberta/preprocessing/get_mask.py b/examples/community/roberta/preprocessing/get_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..74c97a63a9f3994bd4fd6e953222c83fedb229b3 --- /dev/null +++ b/examples/community/roberta/preprocessing/get_mask.py @@ -0,0 +1,260 @@ +import collections +import logging +import os +import random +import time +from enum import IntEnum +from random import choice + +import jieba +import torch + +jieba.setLogLevel(logging.CRITICAL) +import re + +import mask +import numpy as np + +PAD = 0 +MaskedLMInstance = collections.namedtuple("MaskedLMInstance", ["index", "label"]) + + +def map_to_numpy(data): + return np.asarray(data) + + +class PreTrainingDataset(): + + def __init__(self, + tokenizer, + max_seq_length, + backend='python', + max_predictions_per_seq: int = 80, + do_whole_word_mask: bool = True): + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + self.masked_lm_prob = 0.15 + self.backend = backend + self.do_whole_word_mask = do_whole_word_mask + self.max_predictions_per_seq = max_predictions_per_seq + self.vocab_words = list(tokenizer.vocab.keys()) + self.rec = re.compile('[\u4E00-\u9FA5]') + self.whole_rec = re.compile('##[\u4E00-\u9FA5]') + + self.mlm_p = 0.15 + self.mlm_mask_p = 0.8 + self.mlm_tamper_p = 0.05 + self.mlm_maintain_p = 0.1 + + def tokenize(self, doc): + temp = [] + for d in doc: + temp.append(self.tokenizer.tokenize(d)) + return temp + + def create_training_instance(self, instance): + is_next = 1 + raw_text_list = self.get_new_segment(instance) + tokens_a = raw_text_list + assert len(tokens_a) == len(instance) + # tokens_a, tokens_b, is_next = instance.get_values() + # print(f'is_next label:{is_next}') + # Create mapper + tokens = [] + original_tokens = [] + segment_ids = [] + tokens.append("[CLS]") + original_tokens.append('[CLS]') + segment_ids.append(0) + for index, token in enumerate(tokens_a): + tokens.append(token) + original_tokens.append(instance[index]) + segment_ids.append(0) + + tokens.append("[SEP]") + original_tokens.append('[SEP]') + segment_ids.append(0) + + # for token in tokens_b: + # tokens.append(token) + # segment_ids.append(1) + + # tokens.append("[SEP]") + # segment_ids.append(1) + + # Get Masked LM predictions + if self.backend == 'c++': + output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions( + tokens, original_tokens, self.vocab_words, self.tokenizer.vocab, self.max_predictions_per_seq, + self.masked_lm_prob) + elif self.backend == 'python': + output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens) + + # Convert to Ids + input_ids = self.tokenizer.convert_tokens_to_ids(output_tokens) + input_mask = [1] * len(input_ids) + + while len(input_ids) < self.max_seq_length: + input_ids.append(PAD) + segment_ids.append(PAD) + input_mask.append(PAD) + masked_lm_output.append(-1) + return ([ + map_to_numpy(input_ids), + map_to_numpy(input_mask), + map_to_numpy(segment_ids), + map_to_numpy(masked_lm_output), + map_to_numpy([is_next]) + ]) + + def create_masked_lm_predictions(self, tokens): + cand_indexes = [] + for i, token in enumerate(tokens): + if token == "[CLS]" or token == "[SEP]": + continue + if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + + # cand_indexes.append(i) + + random.shuffle(cand_indexes) + output_tokens = list(tokens) + + num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob)))) + + masked_lms = [] + covered_indexes = set() + for index in cand_indexes: + if len(masked_lms) >= num_to_predict: + break + if index in covered_indexes: + continue + covered_indexes.add(index) + + masked_token = None + # 80% mask + if random.random() < 0.8: + masked_token = "[MASK]" + else: + # 10% Keep Original + if random.random() < 0.5: + masked_token = tokens[index] + # 10% replace w/ random word + else: + masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)] + + output_tokens[index] = masked_token + masked_lms.append(MaskedLMInstance(index=index, label=tokens[index])) + + masked_lms = sorted(masked_lms, key=lambda x: x.index) + masked_lm_output = [-1] * len(output_tokens) + for p in masked_lms: + masked_lm_output[p.index] = self.tokenizer.vocab[p.label] + + return (output_tokens, masked_lm_output) + + def get_new_segment(self, segment): + """ + Input a sentence, return a processed sentence: In order to support the Chinese whole word mask, the words that are separated will be marked with a special mark ("#"), so that the subsequent processing module can know which words belong to the same word. + :param segment: a sentence + """ + seq_cws = jieba.lcut(''.join(segment)) + seq_cws_dict = {x: 1 for x in seq_cws} + new_segment = [] + i = 0 + while i < len(segment): + if len(self.rec.findall(segment[i])) == 0: + new_segment.append(segment[i]) + i += 1 + continue + + has_add = False + for length in range(3, 0, -1): + if i + length > len(segment): + continue + if ''.join(segment[i:i + length]) in seq_cws_dict: + new_segment.append(segment[i]) + for l in range(1, length): + new_segment.append('##' + segment[i + l]) + i += length + has_add = True + break + if not has_add: + new_segment.append(segment[i]) + i += 1 + return new_segment + + def create_whole_masked_lm_predictions(self, tokens): + """Creates the predictions for the masked LM objective.""" + + cand_indexes = [] + for (i, token) in enumerate(tokens): + if token == "[CLS]" or token == "[SEP]": + continue + # Whole Word Masking means that if we mask all of the wordpieces + # corresponding to an original word. When a word has been split into + # WordPieces, the first token does not have any marker and any subsequence + # tokens are prefixed with ##. So whenever we see the ## token, we + # append it to the previous set of word indexes. + # + # Note that Whole Word Masking does *not* change the training code + # at all -- we still predict each WordPiece independently, softmaxed + # over the entire vocabulary. + if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + + random.shuffle(cand_indexes) + + output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens] # 去掉"##" + + num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob)))) + + masked_lms = [] + covered_indexes = set() + for index_set in cand_indexes: + if len(masked_lms) >= num_to_predict: + break + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(masked_lms) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + + masked_token = None + # 80% of the time, replace with [MASK] + if random.random() < 0.8: + masked_token = "[MASK]" + else: + # 10% of the time, keep original + if random.random() < 0.5: + masked_token = tokens[index][2:] if len(self.whole_rec.findall( + tokens[index])) > 0 else tokens[index] # 去掉"##" + # 10% of the time, replace with random word + else: + masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)] + + output_tokens[index] = masked_token + + masked_lms.append( + MaskedLMInstance( + index=index, + label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index])) + assert len(masked_lms) <= num_to_predict + masked_lms = sorted(masked_lms, key=lambda x: x.index) + masked_lm_output = [-1] * len(output_tokens) + for p in masked_lms: + masked_lm_output[p.index] = self.tokenizer.vocab[p.label] + + return (output_tokens, masked_lm_output) diff --git a/examples/community/roberta/preprocessing/mask.cpp b/examples/community/roberta/preprocessing/mask.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d44f58eccfc2eaec068f4dbddfe2442b13429f54 --- /dev/null +++ b/examples/community/roberta/preprocessing/mask.cpp @@ -0,0 +1,190 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +const int32_t LONG_SENTENCE_LEN = 512; + +struct MaskedLMInstance { + int index; + std::string label; + MaskedLMInstance(int index, std::string label) { + this->index = index; + this->label = label; + } +}; + +auto get_new_segment( + std::vector segment, std::vector segment_jieba, + const std::vector chinese_vocab) { // const + // std::unordered_set + // &chinese_vocab + std::unordered_set seq_cws_dict; + for (auto word : segment_jieba) { + seq_cws_dict.insert(word); + } + int i = 0; + std::vector new_segment; + int segment_size = segment.size(); + while (i < segment_size) { + if (!chinese_vocab[i]) { // chinese_vocab.find(segment[i]) == + // chinese_vocab.end() + new_segment.emplace_back(segment[i]); + i += 1; + continue; + } + bool has_add = false; + for (int length = 3; length >= 1; length--) { + if (i + length > segment_size) { + continue; + } + std::string chinese_word = ""; + for (int j = i; j < i + length; j++) { + chinese_word += segment[j]; + } + if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) { + new_segment.emplace_back(segment[i]); + for (int j = i + 1; j < i + length; j++) { + new_segment.emplace_back("##" + segment[j]); + } + i += length; + has_add = true; + break; + } + } + if (!has_add) { + new_segment.emplace_back(segment[i]); + i += 1; + } + } + + return new_segment; +} + +bool startsWith(const std::string &s, const std::string &sub) { + return s.find(sub) == 0 ? true : false; +} + +auto create_whole_masked_lm_predictions( + std::vector &tokens, + const std::vector &original_tokens, + const std::vector &vocab_words, + std::map &vocab, const int max_predictions_per_seq, + const double masked_lm_prob) { + // for (auto item : vocab) { + // std::cout << "key=" << std::string(py::str(item.first)) << ", " + // << "value=" << std::string(py::str(item.second)) << + // std::endl; + // } + std::vector > cand_indexes; + std::vector cand_temp; + int tokens_size = tokens.size(); + std::string prefix = "##"; + bool do_whole_masked = true; + + for (int i = 0; i < tokens_size; i++) { + if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") { + continue; + } + if (do_whole_masked && (cand_indexes.size() > 0) && + (tokens[i].rfind(prefix, 0) == 0)) { + cand_temp.emplace_back(i); + } else { + if (cand_temp.size() > 0) { + cand_indexes.emplace_back(cand_temp); + } + cand_temp.clear(); + cand_temp.emplace_back(i); + } + } + auto seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::shuffle(cand_indexes.begin(), cand_indexes.end(), + std::default_random_engine(seed)); + // for (auto i : cand_indexes) { + // for (auto j : i) { + // std::cout << tokens[j] << " "; + // } + // std::cout << std::endl; + // } + // for (auto i : output_tokens) { + // std::cout << i; + // } + // std::cout << std::endl; + + int num_to_predict = std::min(max_predictions_per_seq, + std::max(1, int(tokens_size * masked_lm_prob))); + // std::cout << num_to_predict << std::endl; + + std::set covered_indexes; + std::vector masked_lm_output(tokens_size, -1); + int vocab_words_len = vocab_words.size(); + std::default_random_engine e(seed); + std::uniform_real_distribution u1(0.0, 1.0); + std::uniform_int_distribution u2(0, vocab_words_len - 1); + int mask_cnt = 0; + std::vector output_tokens; + output_tokens = original_tokens; + + for (auto index_set : cand_indexes) { + if (mask_cnt > num_to_predict) { + break; + } + int index_set_size = index_set.size(); + if (mask_cnt + index_set_size > num_to_predict) { + continue; + } + bool is_any_index_covered = false; + for (auto index : index_set) { + if (covered_indexes.find(index) != covered_indexes.end()) { + is_any_index_covered = true; + break; + } + } + if (is_any_index_covered) { + continue; + } + for (auto index : index_set) { + covered_indexes.insert(index); + std::string masked_token; + if (u1(e) < 0.8) { + masked_token = "[MASK]"; + } else { + if (u1(e) < 0.5) { + masked_token = output_tokens[index]; + } else { + int random_index = u2(e); + masked_token = vocab_words[random_index]; + } + } + // masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index])); + masked_lm_output[index] = vocab[output_tokens[index]]; + output_tokens[index] = masked_token; + mask_cnt++; + } + } + + // for (auto p : masked_lms) { + // masked_lm_output[p.index] = vocab[p.label]; + // } + return std::make_tuple(output_tokens, masked_lm_output); +} + +PYBIND11_MODULE(mask, m) { + m.def("create_whole_masked_lm_predictions", + &create_whole_masked_lm_predictions); + m.def("get_new_segment", &get_new_segment); +} diff --git a/examples/community/roberta/preprocessing/sentence_split.py b/examples/community/roberta/preprocessing/sentence_split.py new file mode 100644 index 0000000000000000000000000000000000000000..76e8bd428723d6de230a2a9df4a9f835c0a8ecf3 --- /dev/null +++ b/examples/community/roberta/preprocessing/sentence_split.py @@ -0,0 +1,152 @@ +import argparse +import functools +import json +import multiprocessing +import os +import re +import time +from typing import List + +from tqdm import tqdm + + +def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]: + sent_list = [] + try: + if flag == "zh": + document = re.sub('(?P([。?!…](?![”’"\'])))', r'\g\n', document) + document = re.sub('(?P([。?!]|…{1,2})[”’"\'])', r'\g\n', document) + elif flag == "en": + document = re.sub('(?P([.?!](?![”’"\'])))', r'\g\n', document) + document = re.sub('(?P([?!.]["\']))', r'\g\n', + document) # Special quotation marks + else: + document = re.sub('(?P([。?!….?!](?![”’"\'])))', r'\g\n', document) + + document = re.sub('(?P(([。?!.!?]|…{1,2})[”’"\']))', r'\g\n', + document) # Special quotation marks + + sent_list_ori = document.splitlines() + for sent in sent_list_ori: + sent = sent.strip() + if not sent: + continue + elif len(sent) <= 2: + continue + else: + while len(sent) > limit: + temp = sent[0:limit] + sent_list.append(temp) + sent = sent[limit:] + sent_list.append(sent) + except: + sent_list.clear() + sent_list.append(document) + return sent_list + + +def get_sent(output_path, input_path, fin_list=[], host=-1, seq_len=512) -> None: + + workers = 32 + + if input_path[-1] == '/': + input_path = input_path[:-1] + + cur_path = os.path.join(output_path, str(host) + '.txt') + new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2) + with open(cur_path, 'w', encoding='utf-8') as f: + for fi, fin_path in enumerate(fin_list): + if not os.path.exists(os.path.join(input_path, fin_path[0])): + continue + if '.json' not in fin_path[0]: + continue + + print("Processing ", fin_path[0], " ", fi) + + with open(os.path.join(input_path, fin_path[0]), 'r') as fin: + f_data = [l['content'] for l in json.load(fin)] + + pool = multiprocessing.Pool(workers) + all_sent = pool.imap_unordered(new_split_sentence, f_data, 32) + pool.close() + print('finished..') + + cnt = 0 + for d in tqdm(all_sent): + for i in d: + f.write(i.strip() + '\n') + f.write(']]' + '\n') + cnt += 1 + # if cnt >= 2: + # exit() + + +def getFileSize(filepath, shard): + all_data = [] + for i in os.listdir(filepath): + all_data.append(os.path.join(filepath, i)) + all_size = sum([os.path.getsize(os.path.join(filepath, f)) for f in all_data]) + ans = [[f.split('/')[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data] + ans = sorted(ans, key=lambda x: x[1], reverse=True) + per_size = all_size / shard + real_shard = [] + temp = [] + accu_size = 0 + for i in ans: + accu_size += i[1] + temp.append(i) + if accu_size > per_size: + real_shard.append(temp) + accu_size = 0 + temp = [] + + if len(temp) > 0: + real_shard.append(temp) + + return real_shard + + +def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'): + import socket + host = int(socket.gethostname().split(server_name)[-1]) + + fin_list = real_shard[server_num * base + host - 1] + print(fin_list) + print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}') + return fin_list, host + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--server_num', type=int, default=10, help='number of servers') + parser.add_argument('--seq_len', type=int, default=512, help='sequence length') + parser.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100') + parser.add_argument('--input_path', type=str, required=True, help='input path of original corpus') + parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence') + args = parser.parse_args() + + server_num = args.server_num + seq_len = args.seq_len + shard = args.shard + input_path = args.input_path + output_path = args.output_path + + real_shard = getFileSize(input_path, shard) + + start = time.time() + for index, shard in enumerate(real_shard): + get_sent(output_path, input_path, fin_list=shard, host=index, seq_len=seq_len) + print(f'cost {str(time.time() - start)}') + + # if you have multiple server, you can use code below or modify code to openmpi + + # for i in range(len(real_shard) // server_num + 1): + # fin_list, host = get_start_end(real_shard, i) + + # start = time.time() + # get_sent(output_path, + # input_path, + # fin_list=fin_list, host= 10 * i + host - 1) + + # print(f'cost {str(time.time() - start)}') diff --git a/examples/community/roberta/preprocessing/tokenize_mask.py b/examples/community/roberta/preprocessing/tokenize_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..f3d49c3d965fc680daf0532a5714122b2bf03678 --- /dev/null +++ b/examples/community/roberta/preprocessing/tokenize_mask.py @@ -0,0 +1,267 @@ +import argparse +import multiprocessing +import os +import socket +import time +from random import shuffle + +import h5py +import numpy as np +import psutil +from get_mask import PreTrainingDataset +from tqdm import tqdm +from transformers import AutoTokenizer + + +def get_raw_instance(document, max_sequence_length=512): + """ + Get the initial training instances, split the whole segment into multiple parts according to the max_sequence_length, and return as multiple processed instances. + :param document: document + :param max_sequence_length: + :return: a list. each element is a sequence of text + """ + # document = self.documents[index] + max_sequence_length_allowed = max_sequence_length - 2 + # document = [seq for seq in document if len(seq)= max_sequence_length_allowed: + if len(curr_seq) > 0: + result_list.append(curr_seq) + curr_seq = [] + result_list.append(document[sz_idx][:max_sequence_length_allowed]) + sz_idx += 1 + else: + result_list.append(curr_seq) + curr_seq = [] + + if len(curr_seq) > max_sequence_length_allowed / 2: # /2 + result_list.append(curr_seq) + + # num_instance=int(len(big_list)/max_sequence_length_allowed)+1 + # print("num_instance:",num_instance) + + # result_list=[] + # for j in range(num_instance): + # index=j*max_sequence_length_allowed + # end_index=index+max_sequence_length_allowed if j!=num_instance-1 else -1 + # result_list.append(big_list[index:end_index]) + return result_list + + +def split_numpy_chunk(path, tokenizer, pretrain_data, host): + + documents = [] + instances = [] + + s = time.time() + with open(path, encoding='utf-8') as fd: + document = [] + for i, line in enumerate(tqdm(fd)): + line = line.strip() + # document = line + # if len(document.split("")) <= 3: + # continue + if len(line) > 0 and line[:2] == "]]": # This is end of document + documents.append(document) + document = [] + elif len(line) >= 2: + document.append(line) + if len(document) > 0: + documents.append(document) + print('read_file ', time.time() - s) + + # documents = [x for x in documents if x] + # print(len(documents)) + # print(len(documents[0])) + # print(documents[0][0:10]) + import multiprocessing + from typing import List + + ans = [] + for docs in tqdm(documents): + ans.append(pretrain_data.tokenize(docs)) + print(time.time() - s) + del documents + + instances = [] + for a in tqdm(ans): + raw_ins = get_raw_instance(a) + instances.extend(raw_ins) + del ans + + print('len instance', len(instances)) + + sen_num = len(instances) + seq_len = 512 + input_ids = np.zeros([sen_num, seq_len], dtype=np.int32) + input_mask = np.zeros([sen_num, seq_len], dtype=np.int32) + segment_ids = np.zeros([sen_num, seq_len], dtype=np.int32) + masked_lm_output = np.zeros([sen_num, seq_len], dtype=np.int32) + + for index, ins in tqdm(enumerate(instances)): + mask_dict = pretrain_data.create_training_instance(ins) + input_ids[index] = mask_dict[0] + input_mask[index] = mask_dict[1] + segment_ids[index] = mask_dict[2] + masked_lm_output[index] = mask_dict[3] + + with h5py.File(f'/output/{host}.h5', 'w') as hf: + hf.create_dataset("input_ids", data=input_ids) + hf.create_dataset("input_mask", data=input_ids) + hf.create_dataset("segment_ids", data=segment_ids) + hf.create_dataset("masked_lm_positions", data=masked_lm_output) + + del instances + + +def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_factor, seq_len, file_name): + + if os.path.exists(os.path.join(output_path, f'{file_name}.h5')): + print(f'{file_name}.h5 exists') + return + + documents = [] + instances = [] + + s = time.time() + with open(input_path, 'r', encoding='utf-8') as fd: + document = [] + for i, line in enumerate(tqdm(fd)): + line = line.strip() + if len(line) > 0 and line[:2] == "]]": # This is end of document + documents.append(document) + document = [] + elif len(line) >= 2: + document.append(line) + if len(document) > 0: + documents.append(document) + print(f'read_file cost {time.time() - s}, length is {len(documents)}') + + ans = [] + s = time.time() + pool = multiprocessing.Pool(worker) + encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100) + for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour='cyan'): + ans.append(res) + pool.close() + print((time.time() - s) / 60) + del documents + + instances = [] + for a in tqdm(ans, colour='MAGENTA'): + raw_ins = get_raw_instance(a, max_sequence_length=seq_len) + instances.extend(raw_ins) + del ans + + print('len instance', len(instances)) + + new_instances = [] + for _ in range(dupe_factor): + for ins in instances: + new_instances.append(ins) + + shuffle(new_instances) + instances = new_instances + print('after dupe_factor, len instance', len(instances)) + + sentence_num = len(instances) + input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32) + input_mask = np.zeros([sentence_num, seq_len], dtype=np.int32) + segment_ids = np.zeros([sentence_num, seq_len], dtype=np.int32) + masked_lm_output = np.zeros([sentence_num, seq_len], dtype=np.int32) + + s = time.time() + pool = multiprocessing.Pool(worker) + encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32) + for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour='blue'): + input_ids[index] = mask_dict[0] + input_mask[index] = mask_dict[1] + segment_ids[index] = mask_dict[2] + masked_lm_output[index] = mask_dict[3] + pool.close() + print((time.time() - s) / 60) + + with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf: + hf.create_dataset("input_ids", data=input_ids) + hf.create_dataset("input_mask", data=input_mask) + hf.create_dataset("segment_ids", data=segment_ids) + hf.create_dataset("masked_lm_positions", data=masked_lm_output) + + del instances + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer') + parser.add_argument('--seq_len', type=int, default=512, help='sequence length') + parser.add_argument('--max_predictions_per_seq', + type=int, + default=80, + help='number of shards, e.g., 10, 50, or 100') + parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence') + parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id') + parser.add_argument('--backend', + type=str, + default='python', + help='backend of mask token, python, c++, numpy respectively') + parser.add_argument( + '--dupe_factor', + type=int, + default=1, + help='specifies how many times the preprocessor repeats to create the input from the same article/document') + parser.add_argument('--worker', type=int, default=32, help='number of process') + parser.add_argument('--server_num', type=int, default=10, help='number of servers') + args = parser.parse_args() + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + pretrain_data = PreTrainingDataset(tokenizer, + args.seq_len, + args.backend, + max_predictions_per_seq=args.max_predictions_per_seq) + + data_len = len(os.listdir(args.input_path)) + + for i in range(data_len): + input_path = os.path.join(args.input_path, f'{i}.txt') + if os.path.exists(input_path): + start = time.time() + print(f'process {input_path}') + split_numpy_chunk_pool(input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor, + args.seq_len, i) + end_ = time.time() + print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024)) + print(f'has cost {(end_ - start) / 60}') + print('-' * 100) + print('') + + # if you have multiple server, you can use code below or modify code to openmpi + + # host = int(socket.gethostname().split('GPU')[-1]) + # for i in range(data_len // args.server_num + 1): + # h = args.server_num * i + host - 1 + # input_path = os.path.join(args.input_path, f'{h}.txt') + # if os.path.exists(input_path): + # start = time.time() + # print(f'I am server {host}, process {input_path}') + # split_numpy_chunk_pool(input_path, + # args.output_path, + # pretrain_data, + # args.worker, + # args.dupe_factor, + # args.seq_len, + # h) + # end_ = time.time() + # print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) ) + # print(f'has cost {(end_ - start) / 60}') + # print('-' * 100) + # print('') diff --git a/examples/community/roberta/pretraining/README.md b/examples/community/roberta/pretraining/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c248fc1f570831070b3b7d07e56ee64572405224 --- /dev/null +++ b/examples/community/roberta/pretraining/README.md @@ -0,0 +1,23 @@ +# Pretraining +1. Pretraining roberta through running the script below. Detailed parameter descriptions can be found in the arguments.py. `data_path_prefix` is absolute path specifies output of preprocessing. **You have to modify the *hostfile* according to your cluster.** + +```bash +bash run_pretrain.sh +``` +* `--hostfile`: servers' host name from /etc/hosts +* `--include`: servers which will be used +* `--nproc_per_node`: number of process(GPU) from each server +* `--data_path_prefix`: absolute location of train data, e.g., /h5/0.h5 +* `--eval_data_path_prefix`: absolute location of eval data +* `--tokenizer_path`: tokenizer path contains huggingface tokenizer.json, e.g./tokenizer/tokenizer.json +* `--bert_config`: config.json which represent model +* `--mlm`: model type of backbone, bert or deberta_v2 + +2. if resume training from earylier checkpoint, run the script below. + +```shell +bash run_pretrain_resume.sh +``` +* `--resume_train`: whether to resume training +* `--load_pretrain_model`: absolute path which contains model checkpoint +* `--load_optimizer_lr`: absolute path which contains optimizer checkpoint diff --git a/examples/community/roberta/pretraining/arguments.py b/examples/community/roberta/pretraining/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..40210c4b1be779e7cd029e7deb70b7ed7acad2d3 --- /dev/null +++ b/examples/community/roberta/pretraining/arguments.py @@ -0,0 +1,87 @@ +from numpy import require + +import colossalai + +__all__ = ['parse_args'] + + +def parse_args(): + parser = colossalai.get_default_parser() + + parser.add_argument( + "--distplan", + type=str, + default='CAI_Gemini', + help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", + ) + parser.add_argument( + "--tp_degree", + type=int, + default=1, + help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--placement", + type=str, + default='cpu', + help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--shardinit", + action='store_true', + help= + "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", + ) + + parser.add_argument('--lr', type=float, required=True, help='initial learning rate') + parser.add_argument('--epoch', type=int, required=True, help='number of epoch') + parser.add_argument('--data_path_prefix', type=str, required=True, help="location of the train data corpus") + parser.add_argument('--eval_data_path_prefix', + type=str, + required=True, + help='location of the evaluation data corpus') + parser.add_argument('--tokenizer_path', type=str, required=True, help='location of the tokenizer') + parser.add_argument('--max_seq_length', type=int, default=512, help='sequence length') + parser.add_argument('--refresh_bucket_size', + type=int, + default=1, + help="This param makes sure that a certain task is repeated for this time steps to \ + optimise on the back propogation speed with APEX's DistributedDataParallel") + parser.add_argument("--max_predictions_per_seq", + "--max_pred", + default=80, + type=int, + help="The maximum number of masked tokens in a sequence to be predicted.") + parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="accumulation_steps") + parser.add_argument("--train_micro_batch_size_per_gpu", default=2, type=int, required=True, help="train batch size") + parser.add_argument("--eval_micro_batch_size_per_gpu", default=2, type=int, required=True, help="eval batch size") + parser.add_argument("--num_workers", default=8, type=int, help="") + parser.add_argument("--async_worker", action='store_true', help="") + parser.add_argument("--bert_config", required=True, type=str, help="location of config.json") + parser.add_argument("--wandb", action='store_true', help="use wandb to watch model") + parser.add_argument("--wandb_project_name", default='roberta', help="wandb project name") + parser.add_argument("--log_interval", default=100, type=int, help="report interval") + parser.add_argument("--log_path", type=str, required=True, help="log file which records train step") + parser.add_argument("--tensorboard_path", type=str, required=True, help="location of tensorboard file") + parser.add_argument("--colossal_config", + type=str, + required=True, + help="colossal config, which contains zero config and so on") + parser.add_argument("--ckpt_path", + type=str, + required=True, + help="location of saving checkpoint, which contains model and optimizer") + parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") + parser.add_argument('--vscode_debug', action='store_true', help="use vscode to debug") + parser.add_argument('--load_pretrain_model', default='', type=str, help="location of model's checkpoin") + parser.add_argument( + '--load_optimizer_lr', + default='', + type=str, + help="location of checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step") + parser.add_argument('--resume_train', action='store_true', help="whether resume training from a early checkpoint") + parser.add_argument('--mlm', default='bert', type=str, help="model type, bert or deberta") + parser.add_argument('--checkpoint_activations', action='store_true', help="whether to use gradient checkpointing") + + args = parser.parse_args() + return args diff --git a/examples/community/roberta/pretraining/bert_dataset_provider.py b/examples/community/roberta/pretraining/bert_dataset_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf165ed18f4022a218671f4846d9e92c287f911 --- /dev/null +++ b/examples/community/roberta/pretraining/bert_dataset_provider.py @@ -0,0 +1,16 @@ +class BertDatasetProviderInterface: + + def get_shard(self, index, shuffle=True): + raise NotImplementedError + + def release_shard(self, index): + raise NotImplementedError + + def prefetch_shard(self, index): + raise NotImplementedError + + def get_batch(self, batch_iter): + raise NotImplementedError + + def prefetch_batch(self): + raise NotImplementedError diff --git a/examples/community/roberta/pretraining/evaluation.py b/examples/community/roberta/pretraining/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..009242cd1cf5fcc48d0224394979de743e4f9f0b --- /dev/null +++ b/examples/community/roberta/pretraining/evaluation.py @@ -0,0 +1,76 @@ +import math +import os + +import torch +from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider +from tqdm import tqdm +from utils.global_vars import get_tensorboard_writer, get_timers + + +def evaluate(model, args, logger, global_step, criterion): + evaluate_dataset_provider = NvidiaBertDatasetProvider(args, evaluate=True) + start_shard = 0 + + model.eval() + timers = get_timers() + eval_step = 0 + eval_loss = 0 + cur_loss = 0 + world_size = torch.distributed.get_world_size() + + with torch.no_grad(): + + for shard in range(start_shard, len(os.listdir(args.eval_data_path_prefix))): + + timers('eval_shard_time').start() + + dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard) + # evaluate_dataset_provider.prefetch_shard(shard + 1) + if torch.distributed.get_rank() == 0: + iterator_data = tqdm(enumerate(dataset_iterator), + total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), + colour='MAGENTA', + smoothing=1) + else: + iterator_data = enumerate(dataset_iterator) + + for step, batch_data in iterator_data: #tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1): + + # batch_data = pretrain_dataset_provider.get_batch(batch_index) + eval_step += 1 + input_ids = batch_data[0].cuda() + attention_mask = batch_data[1].cuda() + token_type_ids = batch_data[2].cuda() + mlm_label = batch_data[3].cuda() + # nsp_label = batch_data[5].cuda() + + output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + loss = criterion(output.logits, mlm_label) #prediction_scores + evaluate_dataset_provider.prefetch_batch() + + eval_loss += loss.float().item() + + cur_loss = eval_loss / eval_step + elapsed_time = timers("eval_shard_time").elapsed() + elapsed_time_per_iteration = elapsed_time / eval_step + ppl = math.exp(cur_loss) + + if args.wandb and torch.distributed.get_rank() == 0: + tensorboard_log = get_tensorboard_writer() + tensorboard_log.log_eval({ + 'loss': cur_loss, + 'ppl': ppl, + 'mins_batch': elapsed_time_per_iteration + }, global_step) + + eval_log_str = f'evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ + f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}' + + logger.info(eval_log_str) + logger.info('-' * 100) + logger.info('') + + evaluate_dataset_provider.release_shard() + model.train() + return cur_loss diff --git a/examples/community/roberta/pretraining/hostfile b/examples/community/roberta/pretraining/hostfile new file mode 100644 index 0000000000000000000000000000000000000000..f4e047f01fdd6e2826e5cefd2bf84c8178d561ad --- /dev/null +++ b/examples/community/roberta/pretraining/hostfile @@ -0,0 +1,10 @@ +GPU001 +GPU002 +GPU003 +GPU004 +GPU005 +GPU006 +GPU007 +GPU008 +GPU009 +GPU010 diff --git a/examples/community/roberta/pretraining/loss.py b/examples/community/roberta/pretraining/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..989c2bd5c450462e130fe947a28472db32ea6f49 --- /dev/null +++ b/examples/community/roberta/pretraining/loss.py @@ -0,0 +1,17 @@ +import torch + +__all__ = ['LossForPretraining'] + + +class LossForPretraining(torch.nn.Module): + + def __init__(self, vocab_size): + super(LossForPretraining, self).__init__() + self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1) + self.vocab_size = vocab_size + + def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None): + masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1)) + # next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1)) + total_loss = masked_lm_loss #+ next_sentence_loss + return total_loss diff --git a/examples/community/roberta/pretraining/model/bert.py b/examples/community/roberta/pretraining/model/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..a5da1bea6f655b3e2e36168a8339f848f48fc2ff --- /dev/null +++ b/examples/community/roberta/pretraining/model/bert.py @@ -0,0 +1,1881 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.bert.configuration_bert import BertConfig +from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "bert-base-uncased" +_CONFIG_FOR_DOC = "BertConfig" +_TOKENIZER_FOR_DOC = "BertTokenizer" + +# TokenClassification docstring +_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english" +_TOKEN_CLASS_EXPECTED_OUTPUT = ( + "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] ") +_TOKEN_CLASS_EXPECTED_LOSS = 0.01 + +# QuestionAnswering docstring +_CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2" +_QA_EXPECTED_OUTPUT = "'a nice puppet'" +_QA_EXPECTED_LOSS = 7.41 +_QA_TARGET_START_INDEX = 14 +_QA_TARGET_END_INDEX = 15 + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity" +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'" +_SEQ_CLASS_EXPECTED_LOSS = 0.01 + +BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "bert-base-uncased", + "bert-large-uncased", + "bert-base-cased", + "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese", + "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-large-uncased-whole-word-masking-finetuned-squad", + "bert-large-cased-whole-word-masking-finetuned-squad", + "bert-base-cased-finetuned-mrpc", + "bert-base-german-dbmdz-cased", + "bert-base-german-dbmdz-uncased", + "cl-tohoku/bert-base-japanese", + "cl-tohoku/bert-base-japanese-whole-word-masking", + "cl-tohoku/bert-base-japanese-char", + "cl-tohoku/bert-base-japanese-char-whole-word-masking", + "TurkuNLP/bert-base-finnish-cased-v1", + "TurkuNLP/bert-base-finnish-uncased-v1", + "wietsedv/bert-base-dutch-cased", + # See all BERT models at https://huggingface.co/models?filter=bert +] + + +def load_tf_weights_in_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions.") + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + if version.parse(torch.__version__) > version.parse("1.6.0"): + self.register_buffer( + "token_type_ids", + torch.zeros(self.position_ids.size(), dtype=torch.long), + persistent=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length:seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})") + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = BertAttention(config, position_embedding_type="absolute") + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`") + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward, + self.seq_len_dim, attention_output) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = "bert" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + + +@dataclass +class BertForPreTrainingOutput(ModelOutput): + """ + Output type of [`BertForPreTraining`]. + + Args: + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: torch.FloatTensor = None + seq_relationship_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +BERT_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, +) +class BertModel(BertPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. + """, + BERT_START_DOCSTRING, +) +class BertForPreTraining(BertPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence + pair (see `input_ids` docstring) Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + kwargs (`Dict[str, any]`, optional, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. + + Returns: + + Example: + + ```python + >>> from transformers import BertTokenizer, BertForPreTraining + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + >>> model = BertForPreTraining.from_pretrained("bert-base-uncased") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings("""Bert Model with a `language modeling` head on top for CLM fine-tuning.""", + BERT_START_DOCSTRING) +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING) +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning("If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention.") + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="'paris'", + expected_loss=0.88, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full((effective_batch_size, 1), + self.config.pad_token_id, + dtype=torch.long, + device=input_ids.device) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top.""", + BERT_START_DOCSTRING, +) +class BertForNextSentencePrediction(BertPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Returns: + + Example: + + ```python + >>> from transformers import BertTokenizer, BertForNextSentencePrediction + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` + """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" + " `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + BERT_START_DOCSTRING, +) +class BertForSequenceClassification(BertPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = BertModel(config) + classifier_dropout = (config.classifier_dropout + if config.classifier_dropout is not None else config.hidden_dropout_prob) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + BERT_START_DOCSTRING, +) +class BertForMultipleChoice(BertPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + classifier_dropout = (config.classifier_dropout + if config.classifier_dropout is not None else config.hidden_dropout_prob) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None else None) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + BERT_START_DOCSTRING, +) +class BertForTokenClassification(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + classifier_dropout = (config.classifier_dropout + if config.classifier_dropout is not None else config.hidden_dropout_prob) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT, + expected_loss=_TOKEN_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + BERT_START_DOCSTRING, +) +class BertForQuestionAnswering(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_QA, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + qa_target_start_index=_QA_TARGET_START_INDEX, + qa_target_end_index=_QA_TARGET_END_INDEX, + expected_output=_QA_EXPECTED_OUTPUT, + expected_loss=_QA_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/examples/community/roberta/pretraining/model/deberta_v2.py b/examples/community/roberta/pretraining/model/deberta_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc284911e38723ea9eb6ae2036521096a4323dc --- /dev/null +++ b/examples/community/roberta/pretraining/model/deberta_v2.py @@ -0,0 +1,1643 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the Hugging Face Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeBERTa-v2 model.""" + +import math +from collections.abc import Sequence +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss +from transformers import FillMaskPipeline, T5ForConditionalGeneration, T5Tokenizer +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.deberta_v2.configuration_deberta_v2 import DebertaV2Config +from transformers.pytorch_utils import softmax_backward_data +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DebertaV2Config" +_TOKENIZER_FOR_DOC = "DebertaV2Tokenizer" +_CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge" + +DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/deberta-v2-xlarge", + "microsoft/deberta-v2-xxlarge", + "microsoft/deberta-v2-xlarge-mnli", + "microsoft/deberta-v2-xxlarge-mnli", +] + + +# Copied from transformers.models.deberta.modeling_deberta.ContextPooler +class ContextPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) + self.dropout = StableDropout(config.pooler_dropout) + self.config = config + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + + context_token = hidden_states[:, 0] + context_token = self.dropout(context_token) + pooled_output = self.dense(context_token) + pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output) + return pooled_output + + @property + def output_dim(self): + return self.config.hidden_size + + +# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2 +class XSoftmax(torch.autograd.Function): + """ + Masked Softmax which is optimized for saving memory + + Args: + input (`torch.tensor`): The input tensor that will apply softmax. + mask (`torch.IntTensor`): + The mask matrix where 0 indicate that element will be ignored in the softmax calculation. + dim (int): The dimension that will apply softmax + + Example: + + ```python + >>> import torch + >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax + + >>> # Make a tensor + >>> x = torch.randn([4, 20, 100]) + + >>> # Create a mask + >>> mask = (x > 0).int() + + >>> # Specify the dimension to apply softmax + >>> dim = -1 + + >>> y = XSoftmax.apply(x, mask, dim) + ```""" + + @staticmethod + def forward(self, input, mask, dim): + self.dim = dim + rmask = ~(mask.to(torch.bool)) + + output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min)) + output = torch.softmax(output, self.dim) + output.masked_fill_(rmask, 0) + self.save_for_backward(output) + return output + + @staticmethod + def backward(self, grad_output): + (output,) = self.saved_tensors + inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output) + return inputGrad, None, None + + @staticmethod + def symbolic(g, self, mask, dim): + import torch.onnx.symbolic_helper as sym_help + from torch.onnx.symbolic_opset9 import masked_fill, softmax + + mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"]) + r_mask = g.op( + "Cast", + g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), + to_i=sym_help.cast_pytorch_to_onnx["Byte"], + ) + output = masked_fill(g, self, r_mask, + g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))) + output = softmax(g, output, dim) + return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8))) + + +# Copied from transformers.models.deberta.modeling_deberta.DropoutContext +class DropoutContext(object): + + def __init__(self): + self.dropout = 0 + self.mask = None + self.scale = 1 + self.reuse_mask = True + + +# Copied from transformers.models.deberta.modeling_deberta.get_mask +def get_mask(input, local_context): + if not isinstance(local_context, DropoutContext): + dropout = local_context + mask = None + else: + dropout = local_context.dropout + dropout *= local_context.scale + mask = local_context.mask if local_context.reuse_mask else None + + if dropout > 0 and mask is None: + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool) + + if isinstance(local_context, DropoutContext): + if local_context.mask is None: + local_context.mask = mask + + return mask, dropout + + +# Copied from transformers.models.deberta.modeling_deberta.XDropout +class XDropout(torch.autograd.Function): + """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" + + @staticmethod + def forward(ctx, input, local_ctx): + mask, dropout = get_mask(input, local_ctx) + ctx.scale = 1.0 / (1 - dropout) + if dropout > 0: + ctx.save_for_backward(mask) + return input.masked_fill(mask, 0) * ctx.scale + else: + return input + + @staticmethod + def backward(ctx, grad_output): + if ctx.scale > 1: + (mask,) = ctx.saved_tensors + return grad_output.masked_fill(mask, 0) * ctx.scale, None + else: + return grad_output, None + + +# Copied from transformers.models.deberta.modeling_deberta.StableDropout +class StableDropout(nn.Module): + """ + Optimized dropout module for stabilizing the training + + Args: + drop_prob (float): the dropout probabilities + """ + + def __init__(self, drop_prob): + super().__init__() + self.drop_prob = drop_prob + self.count = 0 + self.context_stack = None + + def forward(self, x): + """ + Call the module + + Args: + x (`torch.tensor`): The input tensor to apply dropout + """ + if self.training and self.drop_prob > 0: + return XDropout.apply(x, self.get_context()) + return x + + def clear_context(self): + self.count = 0 + self.context_stack = None + + def init_context(self, reuse_mask=True, scale=1): + if self.context_stack is None: + self.context_stack = [] + self.count = 0 + for c in self.context_stack: + c.reuse_mask = reuse_mask + c.scale = scale + + def get_context(self): + if self.context_stack is not None: + if self.count >= len(self.context_stack): + self.context_stack.append(DropoutContext()) + ctx = self.context_stack[self.count] + ctx.dropout = self.drop_prob + self.count += 1 + return ctx + else: + return self.drop_prob + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm +class DebertaV2SelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2 +class DebertaV2Attention(nn.Module): + + def __init__(self, config): + super().__init__() + self.self = DisentangledSelfAttention(config) + self.output = DebertaV2SelfOutput(config) + self.config = config + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + self_output = self.self( + hidden_states, + attention_mask, + output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + self_output, att_matrix = self_output + if query_states is None: + query_states = hidden_states + attention_output = self.output(self_output, query_states) + + if output_attentions: + return (attention_output, att_matrix) + else: + return attention_output + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2 +class DebertaV2Intermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm +class DebertaV2Output(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 +class DebertaV2Layer(nn.Module): + + def __init__(self, config): + super().__init__() + self.attention = DebertaV2Attention(config) + self.intermediate = DebertaV2Intermediate(config) + self.output = DebertaV2Output(config) + + def forward( + self, + hidden_states, + attention_mask, + query_states=None, + relative_pos=None, + rel_embeddings=None, + output_attentions=False, + ): + attention_output = self.attention( + hidden_states, + attention_mask, + output_attentions=output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + attention_output, att_matrix = attention_output + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + if output_attentions: + return (layer_output, att_matrix) + else: + return layer_output + + +class ConvLayer(nn.Module): + + def __init__(self, config): + super().__init__() + kernel_size = getattr(config, "conv_kernel_size", 3) + groups = getattr(config, "conv_groups", 1) + self.conv_act = getattr(config, "conv_act", "tanh") + self.conv = nn.Conv1d(config.hidden_size, + config.hidden_size, + kernel_size, + padding=(kernel_size - 1) // 2, + groups=groups) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, residual_states, input_mask): + out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous() + rmask = (1 - input_mask).bool() + out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) + out = ACT2FN[self.conv_act](self.dropout(out)) + + layer_norm_input = residual_states + out + output = self.LayerNorm(layer_norm_input).to(layer_norm_input) + + if input_mask is None: + output_states = output + else: + if input_mask.dim() != layer_norm_input.dim(): + if input_mask.dim() == 4: + input_mask = input_mask.squeeze(1).squeeze(1) + input_mask = input_mask.unsqueeze(2) + + input_mask = input_mask.to(output.dtype) + output_states = output * input_mask + + return output_states + + +class DebertaV2Encoder(nn.Module): + """Modified BertEncoder with relative position bias support""" + + def __init__(self, config): + super().__init__() + + self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)]) + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + + self.position_buckets = getattr(config, "position_buckets", -1) + pos_ebd_size = self.max_relative_positions * 2 + + if self.position_buckets > 0: + pos_ebd_size = self.position_buckets * 2 + + # rel = nn.Parameter(torch.empty((pos_ebd_size, config.hidden_size))) + # self.rel_embeddings = nn.init.normal_(rel, mean=0.0, std=config.initializer_range) + self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size) + + self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] + + if "layer_norm" in self.norm_rel_ebd: + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + + self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None + self.gradient_checkpointing = False + + def get_rel_embedding(self): + att_span = self.position_buckets + rel_index = torch.arange(0, att_span * 2).long().to(self.rel_embeddings.weight.device) + rel_embeddings = self.rel_embeddings(rel_index) + # rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + # rel_embeddings = self.rel_embeddings if self.relative_attention else None + if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): + rel_embeddings = self.LayerNorm(rel_embeddings) + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if attention_mask.dim() <= 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) + attention_mask = attention_mask.byte() + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) + relative_pos = build_relative_position(q, + hidden_states.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions) + return relative_pos + + def forward( + self, + hidden_states, + attention_mask, + output_hidden_states=True, + output_attentions=False, + query_states=None, + relative_pos=None, + return_dict=True, + ): + if attention_mask.dim() <= 2: + input_mask = attention_mask + else: + input_mask = (attention_mask.sum(-2) > 0).byte() + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[0] + else: + next_kv = hidden_states + rel_embeddings = self.get_rel_embedding() + output_states = next_kv + for i, layer_module in enumerate(self.layer): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + output_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + next_kv, + attention_mask, + query_states, + relative_pos, + rel_embeddings, + ) + else: + output_states = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) + + if output_attentions: + output_states, att_m = output_states + + if i == 0 and self.conv is not None: + output_states = self.conv(hidden_states, output_states, input_mask) + + if query_states is not None: + query_states = output_states + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None + else: + next_kv = output_states + + if output_attentions: + all_attentions = all_attentions + (att_m,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if not return_dict: + return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput(last_hidden_state=output_states, + hidden_states=all_hidden_states, + attentions=all_attentions) + + +def make_log_bucket_position(relative_pos, bucket_size, max_position): + sign = np.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos)) + log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid + bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int) + return bucket_pos + + +def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1): + """ + Build relative position according to the query and key + + We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key + \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q - + P_k\\) + + Args: + query_size (int): the length of query + key_size (int): the length of key + bucket_size (int): the size of position bucket + max_position (int): the maximum allowed absolute position + + Return: + `torch.LongTensor`: A tensor with shape [1, query_size, key_size] + + """ + q_ids = np.arange(0, query_size) + k_ids = np.arange(0, key_size) + rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1)) + if bucket_size > 0 and max_position > 0: + rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) + rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long) + rel_pos_ids = rel_pos_ids[:query_size, :] + rel_pos_ids = rel_pos_ids.unsqueeze(0) + return rel_pos_ids + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand +def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand +def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand +def pos_dynamic_expand(pos_index, p2c_att, key_layer): + return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))) + + +class DisentangledSelfAttention(nn.Module): + """ + Disentangled self-attention module + + Parameters: + config (`DebertaV2Config`): + A model config class instance with the configuration to build a new model. The schema is similar to + *BertConfig*, for more details, please refer [`DebertaV2Config`] + + """ + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})") + self.num_attention_heads = config.num_attention_heads + _attention_head_size = config.hidden_size // config.num_attention_heads + self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + + self.share_att_key = getattr(config, "share_att_key", False) + self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.position_buckets = getattr(config, "position_buckets", -1) + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_ebd_size = self.max_relative_positions + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets + + self.pos_dropout = StableDropout(config.hidden_dropout_prob) + + if not self.share_att_key: + if "c2p" in self.pos_att_type: + self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + if "p2c" in self.pos_att_type: + self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = StableDropout(config.attention_probs_dropout_prob) + # self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + + def transpose_for_scores(self, x, attention_heads): + new_x_shape = x.size()[:-1] + (attention_heads, -1) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + """ + Call the module + + Args: + hidden_states (`torch.FloatTensor`): + Input states to the module usually the output from previous layer, it will be the Q,K and V in + *Attention(Q,K,V)* + + attention_mask (`torch.ByteTensor`): + An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum + sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j* + th token. + + output_attentions (`bool`, optional): + Whether return the attention matrix. + + query_states (`torch.FloatTensor`, optional): + The *Q* state in *Attention(Q,K,V)*. + + relative_pos (`torch.LongTensor`): + The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with + values ranging in [*-max_relative_positions*, *max_relative_positions*]. + + rel_embeddings (`torch.FloatTensor`): + The embedding of relative distances. It's a tensor of shape [\\(2 \\times + \\text{max_relative_positions}\\), *hidden_size*]. + + + """ + if query_states is None: + query_states = hidden_states + query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) + key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) + value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads) + + rel_att = None + # Take the dot product between "query" and "key" to get the raw attention scores. + scale_factor = 1 + if "c2p" in self.pos_att_type: + scale_factor += 1 + if "p2c" in self.pos_att_type: + scale_factor += 1 + scale = math.sqrt(query_layer.size(-1) * scale_factor) + attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings) + rel_att = self.disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, + scale_factor) + + if rel_att is not None: + attention_scores = attention_scores + rel_att + attention_scores = attention_scores + attention_scores = attention_scores.view(-1, self.num_attention_heads, attention_scores.size(-2), + attention_scores.size(-1)) + + # bsz x height x length x dimension + attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(attention_probs) + context_layer = torch.bmm(attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), + value_layer) + context_layer = (context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), + context_layer.size(-1)).permute(0, 2, 1, 3).contiguous()) + new_context_layer_shape = context_layer.size()[:-2] + (-1,) + context_layer = context_layer.view(new_context_layer_shape) + if output_attentions: + return (context_layer, attention_probs) + else: + return context_layer + + def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + if relative_pos is None: + q = query_layer.size(-2) + relative_pos = build_relative_position(q, + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions) + if relative_pos.dim() == 2: + relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) + elif relative_pos.dim() == 3: + relative_pos = relative_pos.unsqueeze(1) + # bsz x height x query x key + elif relative_pos.dim() != 4: + raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") + + att_span = self.pos_ebd_size + relative_pos = relative_pos.long().to(query_layer.device) + + # rel_index = torch.arange(0, att_span * 2).long().to(query_layer.device) + # rel_embeddings = rel_embeddings(rel_index).unsqueeze(0) + rel_embeddings = rel_embeddings.unsqueeze(0) + # rel_embeddings = rel_embeddings.unsqueeze(0) + # rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0) + if self.share_att_key: + pos_query_layer = self.transpose_for_scores(self.query_proj(rel_embeddings), + self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1) + pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1) + else: + if "c2p" in self.pos_att_type: + pos_key_layer = self.transpose_for_scores(self.pos_key_proj(rel_embeddings), + self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, + 1) # .split(self.all_head_size, dim=-1) + if "p2c" in self.pos_att_type: + pos_query_layer = self.transpose_for_scores(self.pos_query_proj(rel_embeddings), + self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, + 1) # .split(self.all_head_size, dim=-1) + + score = 0 + # content->position + if "c2p" in self.pos_att_type: + scale = math.sqrt(pos_key_layer.size(-1) * scale_factor) + c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) + c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_att = torch.gather( + c2p_att, + dim=-1, + index=c2p_pos.squeeze(0).expand([query_layer.size(0), + query_layer.size(1), + relative_pos.size(-1)]), + ) + score += c2p_att / scale + + # position->content + if "p2c" in self.pos_att_type: + scale = math.sqrt(pos_query_layer.size(-1) * scale_factor) + if key_layer.size(-2) != query_layer.size(-2): + r_pos = build_relative_position( + key_layer.size(-2), + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + ).to(query_layer.device) + r_pos = r_pos.unsqueeze(0) + else: + r_pos = relative_pos + + p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) + p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) + p2c_att = torch.gather( + p2c_att, + dim=-1, + index=p2c_pos.squeeze(0).expand([query_layer.size(0), + key_layer.size(-2), + key_layer.size(-2)]), + ).transpose(-1, -2) + score += p2c_att / scale + + return score + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm +class DebertaV2Embeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + pad_token_id = getattr(config, "pad_token_id", 0) + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id) + + self.position_biased_input = getattr(config, "position_biased_input", True) + if not self.position_biased_input: + self.position_embeddings = None + else: + self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size) + + if config.type_vocab_size > 0: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size) + + if self.embedding_size != config.hidden_size: + self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(inputs_embeds) + + embeddings = inputs_embeds + if self.position_biased_input: + embeddings += position_embeddings + if self.config.type_vocab_size > 0: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings += token_type_embeddings + + if self.embedding_size != self.config.hidden_size: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2 +class DebertaV2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DebertaV2Config + base_model_prefix = "deberta" + _keys_to_ignore_on_load_missing = ["position_ids"] + _keys_to_ignore_on_load_unexpected = ["position_embeddings"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, DebertaV2Encoder): + module.gradient_checkpointing = value + + +DEBERTA_START_DOCSTRING = r""" + The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled + Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build + on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two + improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data. + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior.``` + + + Parameters: + config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DEBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`DebertaV2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.", + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 +class DebertaV2Model(DebertaV2PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + self.embeddings = DebertaV2Embeddings(config) + self.encoder = DebertaV2Encoder(config) + self.z_steps = 0 + self.config = config + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError("The prune function is not implemented in DeBERTa model.") + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + mask=attention_mask, + inputs_embeds=inputs_embeds, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + encoded_layers = encoder_outputs[1] + + if self.z_steps > 1: + hidden_states = encoded_layers[-2] + layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] + query_states = encoded_layers[-1] + rel_embeddings = self.encoder.get_rel_embedding() + attention_mask = self.encoder.get_attention_mask(attention_mask) + rel_pos = self.encoder.get_rel_pos(embedding_output) + for layer in layers[1:]: + query_states = layer( + hidden_states, + attention_mask, + output_attentions=False, + query_states=query_states, + relative_pos=rel_pos, + rel_embeddings=rel_embeddings, + ) + encoded_layers.append(query_states) + + sequence_output = encoded_layers[-1] + + if not return_dict: + return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2):] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2 +class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.deberta = DebertaV2Model(config) + self.cls = DebertaV2OnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta +class DebertaV2PredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta +class DebertaV2LMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = DebertaV2PredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta +class DebertaV2OnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = DebertaV2LMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@add_start_docstrings( + """ + DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2 +class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.deberta = DebertaV2Model(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, num_labels) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = StableDropout(drop_out) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + # regression task + loss_fn = nn.MSELoss() + logits = logits.view(-1).to(labels.dtype) + loss = loss_fn(logits, labels.view(-1)) + elif labels.dim() == 1 or labels.size(-1) == 1: + label_index = (labels >= 0).nonzero() + labels = labels.long() + if label_index.size(0) > 0: + labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), + logits.size(1))) + labels = torch.gather(labels, 0, label_index.view(-1)) + loss_fct = CrossEntropyLoss() + loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) + else: + loss = torch.tensor(0).to(logits) + else: + log_softmax = nn.LogSoftmax(-1) + loss = -((log_softmax(logits) * labels).sum(-1)).mean() + elif self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput(loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions) + + +@add_start_docstrings( + """ + DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2 +class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput(loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions) + + +@add_start_docstrings( + """ + DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering with Deberta->DebertaV2 +class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + DEBERTA_START_DOCSTRING, +) +class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.deberta = DebertaV2Model(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, 1) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = StableDropout(drop_out) + + self.init_weights() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None else None) + + outputs = self.deberta( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py b/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..72c7bd852a401daaf41a6ef9d5b123c7faf4085f --- /dev/null +++ b/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py @@ -0,0 +1,170 @@ +import json +import logging +import os +import random +import time +from concurrent.futures import ProcessPoolExecutor + +import h5py +import numpy as np +import torch +import torch.distributed as dist +from bert_dataset_provider import BertDatasetProviderInterface +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import RandomSampler + +import colossalai.utils as utils + + +# Workaround because python functions are not picklable +class WorkerInitObj(object): + + def __init__(self, seed): + self.seed = seed + + def __call__(self, id): + np.random.seed(seed=self.seed + id) + random.seed(self.seed + id) + + +def create_pretraining_dataset(input_file, max_predictions_per_seq, num_workers, train_batch_size, worker_init, + data_sampler): + train_data = pretraining_dataset(input_file=input_file, max_predictions_per_seq=max_predictions_per_seq) + train_dataloader = DataLoader(train_data, + sampler=data_sampler(train_data), + batch_size=train_batch_size, + num_workers=num_workers, + worker_init_fn=worker_init, + pin_memory=True) + return train_dataloader, len(train_data) + + +class pretraining_dataset(Dataset): + + def __init__(self, input_file, max_predictions_per_seq): + self.input_file = input_file + self.max_predictions_per_seq = max_predictions_per_seq + f = h5py.File(input_file, "r") + keys = ['input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions'] + self.inputs = [np.asarray(f[key][:]) for key in keys] + f.close() + + def __len__(self): + 'Denotes the total number of samples' + return len(self.inputs[0]) + + def __getitem__(self, index): + + [input_ids, input_mask, segment_ids, masked_lm_labels] = [ + torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else torch.from_numpy( + np.asarray(input[index].astype(np.int64))) for indice, input in enumerate(self.inputs) + ] + + return [input_ids, input_mask, segment_ids, masked_lm_labels] + + +class NvidiaBertDatasetProvider(BertDatasetProviderInterface): + + def __init__(self, args, evaluate=False): + self.num_workers = args.num_workers + self.max_seq_length = args.max_seq_length + self.max_predictions_per_seq = args.max_predictions_per_seq + + self.gradient_accumulation_steps = args.gradient_accumulation_steps + if not evaluate: + self.train_micro_batch_size_per_gpu = args.train_micro_batch_size_per_gpu + else: + self.train_micro_batch_size_per_gpu = args.eval_micro_batch_size_per_gpu + self.logger = args.logger + + self.global_rank = dist.get_rank() + self.world_size = dist.get_world_size() + + # Initialize dataset files + if not evaluate: + self.dataset_files = [ + os.path.join(args.data_path_prefix, f) + for f in os.listdir(args.data_path_prefix) + if os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f + ] + else: + self.dataset_files = [ + os.path.join(args.eval_data_path_prefix, f) + for f in os.listdir(args.eval_data_path_prefix) + if os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and 'h5' in f + ] + + self.dataset_files.sort() + # random.shuffle(self.dataset_files) + self.num_files = len(self.dataset_files) + # self.data_sampler = RandomSampler + self.data_sampler = DistributedSampler + + self.worker_init = WorkerInitObj(args.seed + args.local_rank) + self.dataset_future = None + self.pool = ProcessPoolExecutor(1) + self.data_file = None + self.shuffle = True + + if self.global_rank == 0: + self.logger.info(f"NvidiaBertDatasetProvider - Initialization: num_files = {self.num_files}") + + def get_shard(self, index): + start = time.time() + if self.dataset_future is None: + self.data_file = self._get_shard_file(index) + self.train_dataloader, sample_count = create_pretraining_dataset( + input_file=self.data_file, + max_predictions_per_seq=self.max_predictions_per_seq, + num_workers=self.num_workers, + train_batch_size=self.train_micro_batch_size_per_gpu, + worker_init=self.worker_init, + data_sampler=self.data_sampler) + else: + self.train_dataloader, sample_count = self.dataset_future.result(timeout=None) + + self.logger.info( + f"Data Loading Completed for Pretraining Data from {self.data_file} with {sample_count} samples took {time.time()-start:.2f}s." + ) + + return self.train_dataloader, sample_count + + def release_shard(self): + del self.train_dataloader + self.pool.shutdown() + + def prefetch_shard(self, index): + self.data_file = self._get_shard_file(index) + self.dataset_future = self.pool.submit(create_pretraining_dataset, self.data_file, self.max_predictions_per_seq, + self.num_workers, self.train_micro_batch_size_per_gpu, self.worker_init, + self.data_sampler) + + def get_batch(self, batch_iter): + return batch_iter + + def prefetch_batch(self): + pass + + def _get_shard_file(self, shard_index): + file_index = self._get_shard_file_index(shard_index, self.global_rank) + return self.dataset_files[file_index] + + def _get_shard_file_index(self, shard_index, global_rank): + # if dist.is_initialized() and self.world_size > self.num_files: + # remainder = self.world_size % self.num_files + # file_index = (shard_index * self.world_size) + global_rank + ( + # remainder * shard_index) + # else: + # file_index = shard_index * self.world_size + global_rank + + return shard_index % self.num_files + + def shuffle_dataset(self, epoch): + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.num_files, generator=g).tolist() + new_dataset = [self.dataset_files[i] for i in indices] + self.dataset_files = new_dataset diff --git a/examples/community/roberta/pretraining/pretrain_utils.py b/examples/community/roberta/pretraining/pretrain_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cea6ac2c36e5f0225ce4b50a6af4a551986d0951 --- /dev/null +++ b/examples/community/roberta/pretraining/pretrain_utils.py @@ -0,0 +1,124 @@ +import logging +import os +import sys + +import torch +import transformers +from torch.optim import AdamW +from transformers import ( + AutoModelForMaskedLM, + AutoTokenizer, + BertForPreTraining, + GPT2Config, + GPT2LMHeadModel, + RobertaConfig, + RobertaForMaskedLM, + get_linear_schedule_with_warmup, +) + +from colossalai.core import global_context as gpc +from colossalai.nn.lr_scheduler import LinearWarmupLR +from colossalai.nn.optimizer import FusedAdam, HybridAdam + +sys.path.append(os.getcwd()) +from collections import OrderedDict + +import torch.nn as nn +from model.bert import BertForMaskedLM +from model.deberta_v2 import DebertaV2ForMaskedLM + +__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining'] + + +def get_new_state_dict(state_dict, start_index=13): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[start_index:] + new_state_dict[name] = v + return new_state_dict + + +class LMModel(nn.Module): + + def __init__(self, model, config, args): + super().__init__() + + self.checkpoint = args.checkpoint_activations + self.config = config + self.model = model + if self.checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, token_type_ids=None, attention_mask=None): + # Only return lm_logits + return self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + +def get_model(args, logger): + + if args.mlm == 'bert': + config = transformers.BertConfig.from_json_file(args.bert_config) + model = BertForMaskedLM(config) + elif args.mlm == 'deberta_v2': + config = transformers.DebertaV2Config.from_json_file(args.bert_config) + model = DebertaV2ForMaskedLM(config) + else: + raise Exception("Invalid mlm!") + + if len(args.load_pretrain_model) > 0: + assert os.path.exists(args.load_pretrain_model) + # load_checkpoint(args.load_pretrain_model, model, strict=False) + m_state_dict = torch.load(args.load_pretrain_model, + map_location=torch.device(f"cuda:{torch.cuda.current_device()}")) + # new_state_dict = get_new_state_dict(m_state_dict) + model.load_state_dict(m_state_dict, + strict=True) # must insure that every process have identical parameters !!!!!!! + logger.info("load model success") + + numel = sum([p.numel() for p in model.parameters()]) + if args.checkpoint_activations: + model.gradient_checkpointing_enable() + # model = LMModel(model, config, args) + + return config, model, numel + + +def get_optimizer(model, lr): + param_optimizer = list(model.named_parameters()) + no_decay = ['bias', 'gamma', 'beta', 'LayerNorm'] + + # configure the weight decay for bert models + optimizer_grouped_parameters = [{ + 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], + 'weight_decay': 0.1 + }, { + 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], + 'weight_decay': 0.0 + }] + optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, betas=[0.9, 0.95]) + return optimizer + + +def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1): + # warmup_steps = int(total_steps * warmup_ratio) + lr_scheduler = get_linear_schedule_with_warmup(optimizer, + num_warmup_steps=warmup_steps, + num_training_steps=total_steps, + last_epoch=last_epoch) + # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps) + return lr_scheduler + + +def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step): + model_path = path + '_pytorch_model.bin' + optimizer_lr_path = path + '.op_lrs' + checkpoint = {} + checkpoint['optimizer'] = optimizer.state_dict() + checkpoint['lr_scheduler'] = lr_scheduler.state_dict() + checkpoint['epoch'] = epoch + checkpoint['shard'] = shard + checkpoint['global_step'] = global_step + model_state = model.state_dict() #each process must run model.state_dict() + if gpc.get_global_rank() == 0: + torch.save(checkpoint, optimizer_lr_path) + torch.save(model_state, model_path) diff --git a/examples/community/roberta/pretraining/run_pretrain.sh b/examples/community/roberta/pretraining/run_pretrain.sh new file mode 100644 index 0000000000000000000000000000000000000000..280dba714de5d03db20edc132814bb73d15e273f --- /dev/null +++ b/examples/community/roberta/pretraining/run_pretrain.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env sh + +root_path=$PWD +PY_FILE_PATH="$root_path/run_pretraining.py" + +tensorboard_path="$root_path/tensorboard" +log_path="$root_path/exp_log" +ckpt_path="$root_path/ckpt" + + +mkdir -p $tensorboard_path +mkdir -p $log_path +mkdir -p $ckpt_path + +export PYTHONPATH=$PWD + +env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \ + --include GPU002,GPU003,GPU004,GPU007 \ + --nproc_per_node=8 \ + $PY_FILE_PATH \ + --master_addr GPU007 \ + --master_port 20024 \ + --lr 2.0e-4 \ + --train_micro_batch_size_per_gpu 190 \ + --eval_micro_batch_size_per_gpu 20 \ + --epoch 15 \ + --data_path_prefix /h5 \ + --eval_data_path_prefix /eval_h5 \ + --tokenizer_path /roberta \ + --bert_config /roberta/config.json \ + --tensorboard_path $tensorboard_path \ + --log_path $log_path \ + --ckpt_path $ckpt_path \ + --log_interval 50 \ + --mlm bert \ + --wandb \ + --checkpoint_activations \ diff --git a/examples/community/roberta/pretraining/run_pretrain_resume.sh b/examples/community/roberta/pretraining/run_pretrain_resume.sh new file mode 100644 index 0000000000000000000000000000000000000000..8f443b454d7d7ed7f207a9ebca82f08ef5932577 --- /dev/null +++ b/examples/community/roberta/pretraining/run_pretrain_resume.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env sh + +root_path=$PWD +PY_FILE_PATH="$root_path/run_pretraining.py" + +tensorboard_path="$root_path/tensorboard" +log_path="$root_path/exp_log" +ckpt_path="$root_path/ckpt" + + +mkdir -p $tensorboard_path +mkdir -p $log_path +mkdir -p $ckpt_path + +export PYTHONPATH=$PWD + +env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \ + --include GPU002,GPU003,GPU004,GPU007 \ + --nproc_per_node=8 \ + $PY_FILE_PATH \ + --master_addr GPU007 \ + --master_port 20024 \ + --lr 2.0e-4 \ + --train_micro_batch_size_per_gpu 190 \ + --eval_micro_batch_size_per_gpu 20 \ + --epoch 15 \ + --data_path_prefix /h5 \ + --eval_data_path_prefix /eval_h5 \ + --tokenizer_path /roberta \ + --bert_config /roberta/config.json \ + --tensorboard_path $tensorboard_path \ + --log_path $log_path \ + --ckpt_path $ckpt_path \ + --log_interval 50 \ + --mlm bert \ + --wandb \ + --checkpoint_activations \ + --resume_train \ + --load_pretrain_model /ckpt/1.pt \ + --load_optimizer_lr /ckpt/1.op_lrs \ diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py new file mode 100644 index 0000000000000000000000000000000000000000..9a6ffc1c566165ac63365864740e7dce5b2c09fa --- /dev/null +++ b/examples/community/roberta/pretraining/run_pretraining.py @@ -0,0 +1,263 @@ +import math +import os +import time +from functools import partial + +import torch +from arguments import parse_args +from evaluation import evaluate +from loss import LossForPretraining +from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider +from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt +from tqdm import tqdm +from transformers import AutoTokenizer +from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator +from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables +from utils.logger import Logger + +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper +from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.utils import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ZeroOptimizer + + +def main(): + + args = parse_args() + launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + + # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + + logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug) + + if args.vscode_debug: + colossalai.launch(config={}, + rank=args.rank, + world_size=args.world_size, + host=args.host, + port=args.port, + backend=args.backend) + args.local_rank = -1 + args.log_interval = 1 + else: + colossalai.launch_from_torch(config={}) #args.colossal_config + args.local_rank = int(os.environ["LOCAL_RANK"]) + logger.info( + f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + + f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}' + ) + + log_args(logger, args) + args.tokenizer = tokenizer + args.logger = logger + set_global_variables(launch_time, args.tensorboard_path) + + world_size = torch.distributed.get_world_size() + init_dev = get_current_device() + + # build model, optimizer and criterion + if args.distplan.startswith("CAI"): + # all param must use the same process group. + world_size = torch.distributed.get_world_size() + shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None + default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None + + if args.shardinit and args.distplan != "CAI_Gemini": + raise RuntimeError("You can only use shardinit with CAI_Gemini") + + # build GPT model + with ColoInitContext(device=get_current_device(), + dtype=torch.half, + default_dist_spec=default_dist_spec, + default_pg=shard_pg): + config, model, numel = get_model(args, logger) + + # asign running configurations + gemini_config = None + if args.distplan.startswith("CAI_ZeRO"): + optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) + elif args.distplan == "CAI_Gemini": + gemini_config = dict(strict_ddp_mode=args.tp_degree == 1, + device=get_current_device(), + placement_policy=args.placement, + pin_memory=True, + hidden_dim=model.config.hidden_size, + search_range_mb=128) + optim_config = dict(gpu_margin_mem_ratio=0.) + else: + raise RuntimeError + + # build a highly optimized gpu/cpu optimizer + optimizer = get_optimizer(model, lr=args.lr) + + if args.distplan == "CAI_ZeRO1": + zero_stage = 1 + elif args.distplan == "CAI_ZeRO2": + zero_stage = 2 + elif args.distplan == "CAI_Gemini": + zero_stage = 3 + else: + raise RuntimeError + + # wrap your model and optimizer + model = zero_model_wrapper(model, zero_stage, gemini_config) + optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config) + + logger.info(get_mem_info(prefix='After init optim, ')) + + else: + config, model, numel = get_model(args, logger) + logger.info("no_zero") + + if torch.distributed.get_rank() == 0: + os.mkdir(os.path.join(args.ckpt_path, launch_time)) + + logger.info(f'Model numel: {numel}') + + get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) + + # 144003367 is is the length of the entire dataset + steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) + total_steps = steps_per_epoch * args.epoch + + lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) + + start_epoch = 0 + start_shard = 0 + global_step = 0 + if args.resume_train: + assert os.path.exists(args.load_optimizer_lr) + o_l_state_dict = torch.load(args.load_optimizer_lr, map_location='cpu') + o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1 + optimizer.load_state_dict(o_l_state_dict['optimizer']) + # o_l_state_dict['lr_scheduler']['last_epoch'] + lr_scheduler = get_lr_scheduler(optimizer, + total_steps=total_steps, + last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}") + # if you want delete the above three code, must move the model to gpu. Because in optimizer.step() + lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler']) + + start_epoch = o_l_state_dict['epoch'] + start_shard = o_l_state_dict['shard'] + 1 + # global_step = o_l_state_dict['global_step'] + 1 + logger.info( + f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}' + ) + + criterion = LossForPretraining(config.vocab_size) + + # build dataloader + pretrain_dataset_provider = NvidiaBertDatasetProvider(args) + + logger.info(get_mem_info(prefix='After init model, ')) + + best_loss = None + eval_loss = 0 + train_loss = 0 + timers = get_timers() + timers('interval_time').start() + timers('epoch_time').start() + timers('shard_time').start() + + for epoch in range(start_epoch, args.epoch): + + for shard in range(start_shard, len(os.listdir(args.data_path_prefix))): + + dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard) + # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload + if torch.distributed.get_rank() == 0: + iterator_data = tqdm(enumerate(dataset_iterator), + total=(total_length // args.train_micro_batch_size_per_gpu // world_size), + colour='cyan', + smoothing=1) + else: + iterator_data = enumerate(dataset_iterator) + + model.train() + + for step, batch_data in iterator_data: + + # batch_data = pretrain_dataset_provider.get_batch(batch_index) + input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}") + attention_mask = batch_data[1].cuda(f"cuda:{torch.cuda.current_device()}") + token_type_ids = batch_data[2].cuda(f"cuda:{torch.cuda.current_device()}") + mlm_label = batch_data[3].cuda(f"cuda:{torch.cuda.current_device()}") + # nsp_label = batch_data[5].cuda() + + output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + loss = criterion(output.logits, mlm_label) + pretrain_dataset_provider.prefetch_batch() + + optimizer.backward(loss) + train_loss += loss.float().item() + # if (step + 1) % args.accumulation_step == 0: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + global_step += 1 + + if global_step % args.log_interval == 0 and global_step != 0 \ + and torch.distributed.get_rank() == 0: + elapsed_time = timers('interval_time').elapsed(reset=False) + elapsed_time_per_iteration = elapsed_time / global_step + samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator( + numel, args, config, elapsed_time, global_step, world_size) + + cur_loss = train_loss / args.log_interval + current_lr = lr_scheduler.get_last_lr()[0] + log_str = f'| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ + f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}' + logger.info(log_str, print_=False) + + if args.wandb: + tensorboard_log = get_tensorboard_writer() + tensorboard_log.log_train( + { + 'lr': current_lr, + 'loss': cur_loss, + 'ppl': math.exp(cur_loss), + 'mins_batch': elapsed_time_per_iteration + }, global_step) + + train_loss = 0 + + logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins') + logger.info('*' * 100) + + eval_loss += evaluate(model, args, logger, global_step, criterion) + save_ckpt(model, optimizer, lr_scheduler, + os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, + shard, global_step) + + eval_loss /= len(os.listdir(args.data_path_prefix)) + logger.info( + f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + + f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}') + logger.info('-' * 100) + if args.wandb and torch.distributed.get_rank() == 0: + tensorboard_log = get_tensorboard_writer() + tensorboard_log.log_eval({ + 'all_eval_shard_loss': eval_loss, + }, epoch) + start_shard = 0 + eval_loss = 0 + + pretrain_dataset_provider.release_shard() + + logger.info('Congratulation, training has finished!!!') + + +if __name__ == '__main__': + main() diff --git a/examples/community/roberta/pretraining/utils/WandbLog.py b/examples/community/roberta/pretraining/utils/WandbLog.py new file mode 100644 index 0000000000000000000000000000000000000000..b68ba8387dcdb8534d957268dd5606337ba1f74d --- /dev/null +++ b/examples/community/roberta/pretraining/utils/WandbLog.py @@ -0,0 +1,42 @@ +import os +import time + +import wandb +from torch.utils.tensorboard import SummaryWriter + + +class WandbLog: + + @classmethod + def init_wandb(cls, project, notes=None, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None): + wandb.init(project=project, notes=notes, name=name, config=config) + + @classmethod + def log(cls, result, model=None, gradient=None): + wandb.log(result) + + if model: + wandb.watch(model) + + if gradient: + wandb.watch(gradient) + + +class TensorboardLog: + + def __init__(self, location, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None): + if not os.path.exists(location): + os.mkdir(location) + self.writer = SummaryWriter(location, comment=name) + + def log_train(self, result, step): + for k, v in result.items(): + self.writer.add_scalar(f'{k}/train', v, step) + + def log_eval(self, result, step): + for k, v in result.items(): + self.writer.add_scalar(f'{k}/eval', v, step) + + def log_zeroshot(self, result, step): + for k, v in result.items(): + self.writer.add_scalar(f'{k}_acc/eval', v, step) diff --git a/examples/community/roberta/pretraining/utils/exp_util.py b/examples/community/roberta/pretraining/utils/exp_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0cdb56bad03117ddfd181a4bbb313b1d6783dada --- /dev/null +++ b/examples/community/roberta/pretraining/utils/exp_util.py @@ -0,0 +1,114 @@ +import functools +import os +import shutil + +import psutil +import torch + +from colossalai.core import global_context as gpc + + +def logging(s, log_path, print_=True, log_=True): + if print_: + print(s) + if log_: + with open(log_path, 'a+') as f_log: + f_log.write(s + '\n') + + +def get_logger(log_path, **kwargs): + return functools.partial(logging, log_path=log_path, **kwargs) + + +def create_exp_dir(dir_path, scripts_to_save=None, debug=False): + if debug: + print('Debug Mode : no experiment dir created') + return functools.partial(logging, log_path=None, log_=False) + + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + print('Experiment dir : {}'.format(dir_path)) + if scripts_to_save is not None: + script_path = os.path.join(dir_path, 'scripts') + if not os.path.exists(script_path): + os.makedirs(script_path) + for script in scripts_to_save: + dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) + shutil.copyfile(script, dst_file) + + return get_logger(log_path=os.path.join(dir_path, 'log.txt')) + + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + + +def get_parameters_in_billions(model, world_size=1): + gpus_per_model = world_size + + approx_parameters_in_billions = sum([ + sum([p.ds_numel if hasattr(p, 'ds_id') else p.nelement() + for p in model_module.parameters()]) + for model_module in model + ]) + + return approx_parameters_in_billions * gpus_per_model / (1e9) + + +def throughput_calculator(numel, args, config, iteration_time, total_iterations, world_size=1): + gpus_per_model = 1 + batch_size = args.train_micro_batch_size_per_gpu + samples_per_model = batch_size * args.max_seq_length + model_replica_count = world_size / gpus_per_model + approx_parameters_in_billions = numel + elapsed_time_per_iter = iteration_time / total_iterations + samples_per_second = batch_size / elapsed_time_per_iter + + #flops calculator + hidden_size = config.hidden_size + num_layers = config.num_hidden_layers + vocab_size = config.vocab_size + + # General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of + # https://arxiv.org/pdf/2104.04473.pdf). + # The factor of 4 is when used with activation check-pointing, + # otherwise it will be 3. + checkpoint_activations_factor = 4 if args.checkpoint_activations else 3 + flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * + (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) + + (vocab_size / (16. * num_layers * hidden_size))) + tflops = flops_per_iteration / (elapsed_time_per_iter * (10**12)) + return samples_per_second, tflops, approx_parameters_in_billions + + +def synchronize(): + if not torch.distributed.is_available(): + return + if not torch.distributed.is_intialized(): + return + world_size = torch.distributed.get_world_size() + if world_size == 1: + return + torch.distributed.barrier() + + +def log_args(logger, args): + logger.info('--------args----------') + message = '\n'.join([f'{k:<30}: {v}' for k, v in vars(args).items()]) + message += '\n' + message += '\n'.join([f'{k:<30}: {v}' for k, v in gpc.config.items()]) + logger.info(message) + logger.info('--------args----------\n') diff --git a/examples/community/roberta/pretraining/utils/global_vars.py b/examples/community/roberta/pretraining/utils/global_vars.py new file mode 100644 index 0000000000000000000000000000000000000000..7b0c5a2be73d914bc6bccdc9d383432e2bfac1f9 --- /dev/null +++ b/examples/community/roberta/pretraining/utils/global_vars.py @@ -0,0 +1,130 @@ +import time + +import torch + +from .WandbLog import TensorboardLog + +_GLOBAL_TIMERS = None +_GLOBAL_TENSORBOARD_WRITER = None + + +def set_global_variables(launch_time, tensorboard_path): + _set_timers() + _set_tensorboard_writer(launch_time, tensorboard_path) + + +def _set_timers(): + """Initialize timers.""" + global _GLOBAL_TIMERS + _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') + _GLOBAL_TIMERS = Timers() + + +def _set_tensorboard_writer(launch_time, tensorboard_path): + """Set tensorboard writer.""" + global _GLOBAL_TENSORBOARD_WRITER + _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, 'tensorboard writer') + if torch.distributed.get_rank() == 0: + _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f'/{launch_time}', launch_time) + + +def get_timers(): + """Return timers.""" + _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers') + return _GLOBAL_TIMERS + + +def get_tensorboard_writer(): + """Return tensorboard writer. It can be None so no need + to check if it is initialized.""" + return _GLOBAL_TENSORBOARD_WRITER + + +def _ensure_var_is_initialized(var, name): + """Make sure the input variable is not None.""" + assert var is not None, '{} is not initialized.'.format(name) + + +def _ensure_var_is_not_initialized(var, name): + """Make sure the input variable is not None.""" + assert var is None, '{} is already initialized.'.format(name) + + +class _Timer: + """Timer.""" + + def __init__(self, name): + self.name_ = name + self.elapsed_ = 0.0 + self.started_ = False + self.start_time = time.time() + + def start(self): + """Start the timer.""" + # assert not self.started_, 'timer has already been started' + torch.cuda.synchronize() + self.start_time = time.time() + self.started_ = True + + def stop(self): + """Stop the timer.""" + assert self.started_, 'timer is not started' + torch.cuda.synchronize() + self.elapsed_ += (time.time() - self.start_time) + self.started_ = False + + def reset(self): + """Reset timer.""" + self.elapsed_ = 0.0 + self.started_ = False + + def elapsed(self, reset=True): + """Calculate the elapsed time.""" + started_ = self.started_ + # If the timing in progress, end it first. + if self.started_: + self.stop() + # Get the elapsed time. + elapsed_ = self.elapsed_ + # Reset the elapsed time + if reset: + self.reset() + # If timing was in progress, set it back. + if started_: + self.start() + return elapsed_ + + +class Timers: + """Group of timers.""" + + def __init__(self): + self.timers = {} + + def __call__(self, name): + if name not in self.timers: + self.timers[name] = _Timer(name) + return self.timers[name] + + def write(self, names, writer, iteration, normalizer=1.0, reset=False): + """Write timers to a tensorboard writer""" + # currently when using add_scalars, + # torch.utils.add_scalars makes each timer its own run, which + # polutes the runs list, so we just add each as a scalar + assert normalizer > 0.0 + for name in names: + value = self.timers[name].elapsed(reset=reset) / normalizer + writer.add_scalar(name + '-time', value, iteration) + + def log(self, names, normalizer=1.0, reset=True): + """Log a group of timers.""" + assert normalizer > 0.0 + string = 'time (ms)' + for name in names: + elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer + string += ' | {}: {:.2f}'.format(name, elapsed_time) + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1): + print(string, flush=True) + else: + print(string, flush=True) diff --git a/examples/community/roberta/pretraining/utils/logger.py b/examples/community/roberta/pretraining/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..75c9bf4bef251f4b5914ed55f87d1ddc55cb2504 --- /dev/null +++ b/examples/community/roberta/pretraining/utils/logger.py @@ -0,0 +1,30 @@ +import logging +import os + +import torch.distributed as dist + +logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO) +logger = logging.getLogger(__name__) + + +class Logger(): + + def __init__(self, log_path, cuda=False, debug=False): + self.logger = logging.getLogger(__name__) + self.cuda = cuda + self.log_path = log_path + self.debug = debug + + def info(self, message, log_=True, print_=True, *args, **kwargs): + if (self.cuda and dist.get_rank() == 0) or not self.cuda: + if print_: + self.logger.info(message, *args, **kwargs) + + if log_: + with open(self.log_path, 'a+') as f_log: + f_log.write(message + '\n') + + def error(self, message, *args, **kwargs): + self.logger.error(message, *args, **kwargs) diff --git a/examples/community/roberta/requirements.txt b/examples/community/roberta/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..de082defb14a0a278377803667b54428460d061a --- /dev/null +++ b/examples/community/roberta/requirements.txt @@ -0,0 +1,7 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 +tqdm +tensorboard +numpy +h5py +wandb diff --git a/examples/images/diffusion/LICENSE b/examples/images/diffusion/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0e609df0d8cd3b5d11a1ea962a56b604b70846a5 --- /dev/null +++ b/examples/images/diffusion/LICENSE @@ -0,0 +1,82 @@ +Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors + +CreativeML Open RAIL-M +dated August 22, 2022 + +Section I: PREAMBLE + +Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation. + +Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations. + +In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation. + +Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI. + +This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model. + +NOW THEREFORE, You and Licensor agree as follows: + +1. Definitions + +- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. +- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. +- "Output" means the results of operating a Model as embodied in informational content resulting therefrom. +- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. +- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. +- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any. +- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access. +- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. +- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator. +- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You. +- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." +- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model. + +Section II: INTELLECTUAL PROPERTY RIGHTS + +Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model. +3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed. + +Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION + +4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions: +Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material. +You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; +You must cause any modified files to carry prominent notices stating that You changed the files; +You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model. +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. +5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5). +6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License. + +Section IV: OTHER PROVISIONS + +7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model. +8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors. +9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License. +10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. +11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. +12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. + +END OF TERMS AND CONDITIONS + + + + +Attachment A + +Use Restrictions + +You agree not to use the Model or Derivatives of the Model: +- In any way that violates any applicable national, federal, state, local or international law or regulation; +- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; +- To generate or disseminate verifiably false information and/or content with the purpose of harming others; +- To generate or disseminate personal identifiable information that can be used to harm an individual; +- To defame, disparage or otherwise harass others; +- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation; +- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics; +- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; +- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories; +- To provide medical advice and medical results interpretation; +- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use). diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0c7f42ded318774a05ee78345285fb81f8fba4a8 --- /dev/null +++ b/examples/images/diffusion/README.md @@ -0,0 +1,297 @@ +# ColoDiffusion: Stable Diffusion with Colossal-AI + +Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion). + +

+ +

+ +- [Training](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): Reduce Stable Diffusion memory consumption by up to 5.6x and hardware cost by up to 46x (from A100 to RTX3060). + +

+ +

+ + +- [DreamBooth Fine-tuning](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/dreambooth): Personalize your model using just 3-5 images of the desired subject. + +

+ +

+ + +- [Inference](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): Reduce inference GPU memory consumption by 2.5x. + + +More details can be found in our [blog of Stable Diffusion v1](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) and [blog of Stable Diffusion v2](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0). + + +## Roadmap +This project is in rapid development. + +- [X] Train a stable diffusion model v1/v2 from scatch +- [X] Finetune a pretrained Stable diffusion v1 model +- [X] Inference a pretrained model using PyTorch +- [ ] Finetune a pretrained Stable diffusion v2 model +- [ ] Inference a pretrained model using TensoRT + +## Installation + +### Option #1: Install from source +#### Step 1: Requirements + +To begin with, make sure your operating system has the cuda version suitable for this exciting training session, which is cuda11.6/11.8. For your convience, we have set up the rest of packages here. You can create and activate a suitable [conda](https://conda.io/) environment named `ldm` : + +``` +conda env create -f environment.yaml +conda activate ldm +``` + +You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running: + +``` +conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch +pip install transformers diffusers invisible-watermark +``` + +#### Step 2: Install [Colossal-AI](https://colossalai.org/download/) From Our Official Website + +You can install the latest version (0.2.7) from our official website or from source. Notice that the suitable version for this training is colossalai(0.2.5), which stands for torch(1.12.1). + +##### Download suggested version for this training + +``` +pip install colossalai==0.2.5 +``` + +##### Download the latest version from pip for latest torch version + +``` +pip install colossalai +``` + +##### From source: + +``` +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI + +# install colossalai +CUDA_EXT=1 pip install . +``` + +#### Step 3: Accelerate with flash attention by xformers (Optional) + +Notice that xformers will accelerate the training process at the cost of extra disk space. The suitable version of xformers for this training process is 0.0.12, which can be downloaded directly via pip. For more release versions, feel free to check its official website: [XFormers](https://pypi.org/project/xformers/) + +``` +pip install xformers==0.0.12 +``` + +### Option #2: Use Docker + +To use the stable diffusion Docker image, you can either build using the provided the [Dockerfile](./docker/Dockerfile) or pull a Docker image from our Docker hub. + +``` +# 1. build from dockerfile +cd ColossalAI/examples/images/diffusion/docker +docker build -t hpcaitech/diffusion:0.2.0 . + +# 2. pull from our docker hub +docker pull hpcaitech/diffusion:0.2.0 +``` + +Once you have the image ready, you can launch the image with the following command + +```bash +######################## +# On Your Host Machine # +######################## +# make sure you start your image in the repository root directory +cd ColossalAI + +# run the docker container +docker run --rm \ + -it --gpus all \ + -v $PWD:/workspace \ + -v :/data/scratch \ + -v :/root/.cache/huggingface \ + hpcaitech/diffusion:0.2.0 \ + /bin/bash + +######################## +# Inside a Container # +######################## +# Once you have entered the docker container, go to the stable diffusion directory for training +cd examples/images/diffusion/ + +# Download the model checkpoint from pretrained (See the following steps) +# Set up your configuration the "train_colossalai.sh" (See the following steps) +# start training with colossalai +bash train_colossalai.sh +``` + +It is important for you to configure your volume mapping in order to get the best training experience. +1. **Mandatory**, mount your prepared data to `/data/scratch` via `-v :/data/scratch`, where you need to replace `` with the actual data path on your machine. Notice that within docker we need to transform the Windows path to a Linux one, e.g. `C:\User\Desktop` into `/mnt/c/User/Desktop`. +2. **Recommended**, store the downloaded model weights to your host machine instead of the container directory via `-v :/root/.cache/huggingface`, where you need to replace the `` with the actual path. In this way, you don't have to repeatedly download the pretrained weights for every `docker run`. +3. **Optional**, if you encounter any problem stating that shared memory is insufficient inside container, please add `-v /dev/shm:/dev/shm` to your `docker run` command. + + +## Download the model checkpoint from pretrained + +### stable-diffusion-v2-base (Recommended) + +``` +wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt +``` + +### stable-diffusion-v1-4 + +``` +git lfs install +git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 +``` + +### stable-diffusion-v1-5 from runway + +``` +git lfs install +git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 +``` + +## Dataset + +The dataSet is from [LAION-5B](https://laion.ai/blog/laion-5b/), the subset of [LAION](https://laion.ai/), +you should the change the `data.file_path` in the `config/train_colossalai.yaml` + +## Training + +We provide the script `train_colossalai.sh` to run the training task with colossalai. Meanwhile, we have enlightened other training process such as DDP model in PyTorch. You can also use `train_ddp.sh` to run the training task with ddp to compare the corresponding performance. + +In `train_colossalai.sh` the main command is + +``` +python main.py --logdir /tmp/ --train --base configs/train_colossalai.yaml --ckpt 512-base-ema.ckpt +``` + +- You can change the `--logdir` to decide where to save the log information and the last checkpoint. + - You will find your ckpt in `logdir/checkpoints` or `logdir/diff_tb/version_0/checkpoints` + - You will find your train config yaml in `logdir/configs` +- You can add the `--ckpt` if you want to load the pretrained model, for example `512-base-ema.ckpt` +- You can change the `--base` to specify the path of config yaml + +### Training config + +You can change the training config in the yaml file + +- devices: device number used for training, default = 8 +- max_epochs: max training epochs, default = 2 +- precision: the precision type used in training, default = 16 (fp16), you must use fp16 if you want to apply colossalai +- placement_policy: the training strategy supported by Colossal AI, default = 'cuda', which refers to loading all the parameters into cuda memory. On the other hand, 'cpu' refers to 'cpu offload' strategy while 'auto' enables 'Gemini', both featured by Colossal AI. +- more information about the configuration of ColossalAIStrategy can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/model_parallel.html#colossal-ai) + + +## Finetune Example +### Training on Teyvat Datasets + +We provide the finetuning example on [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset, which is create by BLIP generated captions. + +You can run by config `configs/Teyvat/train_colossalai_teyvat.yaml` +``` +python main.py --logdir /tmp/ -t -b configs/Teyvat/train_colossalai_teyvat.yaml +``` + +## Inference + +You can get your training last.ckpt and train config.yaml in your `--logdir`, and run by +``` +python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms + --outdir ./output \ + --ckpt path/to/logdir/checkpoints/last.ckpt \ + --config /path/to/logdir/configs/project.yaml \ +``` + +```commandline +usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA] + [--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS] [--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT] + [--seed SEED] [--precision {full,autocast}] + +optional arguments: + -h, --help show this help message and exit + --prompt [PROMPT] the prompt to render + --outdir [OUTDIR] dir to write results to + --skip_grid do not save a grid, only individual samples. Helpful when evaluating lots of samples + --skip_save do not save individual samples. For speed measurements. + --ddim_steps DDIM_STEPS + number of ddim sampling steps + --plms use plms sampling + --laion400m uses the LAION400M model + --fixed_code if enabled, uses the same starting code across samples + --ddim_eta DDIM_ETA ddim eta (eta=0.0 corresponds to deterministic sampling + --n_iter N_ITER sample this often + --H H image height, in pixel space + --W W image width, in pixel space + --C C latent channels + --f F downsampling factor + --n_samples N_SAMPLES + how many samples to produce for each given prompt. A.k.a. batch size + --n_rows N_ROWS rows in the grid (default: n_samples) + --scale SCALE unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) + --from-file FROM_FILE + if specified, load prompts from this file + --config CONFIG path to config which constructs model + --ckpt CKPT path to checkpoint of model + --seed SEED the seed (for reproducible sampling) + --use_int8 whether to use quantization method + --precision {full,autocast} + evaluate at this precision +``` + +## Invitation to open-source contribution +Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models! + +You may contact us or participate in the following ways: +1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks! +2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). +3. Join the Colossal-AI community on +[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. +4. Send your official proposal to email contact@hpcaitech.com + +Thanks so much to all of our amazing contributors! + +## Comments + +- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion) +, [lucidrains](https://github.com/lucidrains/denoising-diffusion-pytorch), +[Stable Diffusion](https://github.com/CompVis/stable-diffusion), [Lightning](https://github.com/Lightning-AI/lightning) and [Hugging Face](https://huggingface.co/CompVis/stable-diffusion). +Thanks for open-sourcing! + +- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories). + +- The implementation of [flash attention](https://github.com/HazyResearch/flash-attention) is from [HazyResearch](https://github.com/HazyResearch). + +## BibTeX + +``` +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +@misc{rombach2021highresolution, + title={High-Resolution Image Synthesis with Latent Diffusion Models}, + author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer}, + year={2021}, + eprint={2112.10752}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +@article{dao2022flashattention, + title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness}, + author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, + journal={arXiv preprint arXiv:2205.14135}, + year={2022} +} +``` diff --git a/examples/images/diffusion/configs/Inference/v2-inference-v.yaml b/examples/images/diffusion/configs/Inference/v2-inference-v.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b05955d3faf7a1612d396d119140a52791187313 --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-inference-v.yaml @@ -0,0 +1,59 @@ +model: + base_learning_rate: 1.0e-4 + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + + cond_stage_config: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Inference/v2-inference.yaml b/examples/images/diffusion/configs/Inference/v2-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5d8d583d06d10268c0e5e36f36b79c86ccdd0572 --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-inference.yaml @@ -0,0 +1,58 @@ +model: + base_learning_rate: 1.0e-4 + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + + cond_stage_config: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Inference/v2-inpainting-inference.yaml b/examples/images/diffusion/configs/Inference/v2-inpainting-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ffaa5e8da2ad09a105ef13fd050b31636477aded --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-inpainting-inference.yaml @@ -0,0 +1,147 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + unet_config: + use_checkpoint: True + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + + cond_stage_config: + freeze: True + layer: "penultimate" + + +data: + tar_base: null # for concat as in LAION-A + p_unsafe_threshold: 0.1 + filter_word_list: "data/filters.yaml" + max_pwatermark: 0.45 + batch_size: 8 + num_workers: 6 + multinode: True + min_size: 512 + train: + shards: + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" + shuffle: 10000 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + # NOTE use enough shards to avoid empty validation loops in workers + validation: + shards: + - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " + shuffle: 0 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.CenterCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + +lightning: + find_unused_parameters: True + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 10000 + + image_logger: + enable_autocast: False + disabled: False + batch_frequency: 1000 + max_images: 4 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 4 + unconditional_guidance_scale: 5.0 + unconditional_guidance_label: [""] + ddim_steps: 50 # todo check these out for depth2img, + ddim_eta: 0.0 # todo check these out for depth2img, + + trainer: + benchmark: True + val_check_interval: 5000000 + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 diff --git a/examples/images/diffusion/configs/Inference/v2-midas-inference.yaml b/examples/images/diffusion/configs/Inference/v2-midas-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..01d3729f1590caf66bea903935ec1df7dc223c02 --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-midas-inference.yaml @@ -0,0 +1,63 @@ +model: + base_learning_rate: 5.0e-07 + target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + depth_stage_config: + model_type: "dpt_hybrid" + + unet_config: + use_checkpoint: True + image_size: 32 # unused + in_channels: 5 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + + cond_stage_config: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Inference/x4-upscaling.yaml b/examples/images/diffusion/configs/Inference/x4-upscaling.yaml new file mode 100644 index 0000000000000000000000000000000000000000..426d387ca611c730202eea71f431bac2afb8e6b7 --- /dev/null +++ b/examples/images/diffusion/configs/Inference/x4-upscaling.yaml @@ -0,0 +1,66 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion + params: + parameterization: "v" + low_scale_key: "lr" + linear_start: 0.0001 + linear_end: 0.02 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 128 + channels: 4 + cond_stage_trainable: false + conditioning_key: "hybrid-adm" + monitor: val/loss_simple_ema + scale_factor: 0.08333 + use_ema: False + + low_scale_config: + noise_schedule_config: # image space + linear_start: 0.0001 + linear_end: 0.02 + max_noise_level: 350 + + unet_config: + use_checkpoint: True + num_classes: 1000 # timesteps for noise conditioning (here constant, just need one) + image_size: 128 + in_channels: 7 + out_channels: 4 + model_channels: 256 + attention_resolutions: [ 2,4,8] + num_res_blocks: 2 + channel_mult: [ 1, 2, 2, 4] + disable_self_attentions: [True, True, True, False] + disable_middle_self_attn: False + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + use_linear_in_transformer: True + + first_stage_config: + embed_dim: 4 + ddconfig: + # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though) + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + + + cond_stage_config: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Teyvat/README.md b/examples/images/diffusion/configs/Teyvat/README.md new file mode 100644 index 0000000000000000000000000000000000000000..65ba3fb800408704d858ba2ea4bdce95ef1178fc --- /dev/null +++ b/examples/images/diffusion/configs/Teyvat/README.md @@ -0,0 +1,8 @@ +# Dataset Card for Teyvat BLIP captions +Dataset used to train [Teyvat characters text to image model](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion). + +BLIP generated captions for characters images from [genshin-impact fandom wiki](https://genshin-impact.fandom.com/wiki/Character#Playable_Characters)and [biligame wiki for genshin impact](https://wiki.biligame.com/ys/%E8%A7%92%E8%89%B2). + +For each row the dataset contains `image` and `text` keys. `image` is a varying size PIL png, and `text` is the accompanying text caption. Only a train split is provided. + +The `text` include the tag `Teyvat`, `Name`,`Element`, `Weapon`, `Region`, `Model type`, and `Description`, the `Description` is captioned with the [pre-trained BLIP model](https://github.com/salesforce/BLIP). diff --git a/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e760124c7a4daf73c128de2a16443ed46d6f779 --- /dev/null +++ b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml @@ -0,0 +1,112 @@ +model: + base_learning_rate: 1.0e-4 + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + ckpt: None # use ckpt path + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + + unet_config: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + + cond_stage_config: + freeze: True + layer: "penultimate" + +data: + batch_size: 16 + num_workers: 4 + train: + target: ldm.data.teyvat.hf_dataset + params: + path: Fazzie/Teyvat + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + - target: torchvision.transforms.RandomHorizontalFlip + +lightning: + trainer: + accelerator: 'gpu' + devices: 2 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + use_chunk: True + enable_distributed_storage: True + placement_policy: cuda + force_outputs_fp32: true + min_chunk_size: 64 + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname diff --git a/examples/images/diffusion/configs/train_colossalai.yaml b/examples/images/diffusion/configs/train_colossalai.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5f745286a7197260d3dac31d2420c6f090d23651 --- /dev/null +++ b/examples/images/diffusion/configs/train_colossalai.yaml @@ -0,0 +1,109 @@ +model: + base_learning_rate: 1.0e-4 + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + scheduler_config: # 10000 warmup steps + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + + unet_config: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + + + cond_stage_config: + freeze: True + layer: "penultimate" + +data: + batch_size: 128 + wrap: False + # num_workwers should be 2 * batch_size, and total num less than 1024 + # e.g. if use 8 devices, no more than 128 + num_workers: 128 + train: + target: ldm.data.base.Txt2ImgIterableBaseDataset + params: + file_path: # YOUR DATASET_PATH + world_size: 1 + rank: 0 + +lightning: + trainer: + accelerator: 'gpu' + devices: 2 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + use_chunk: True + enable_distributed_storage: True + placement_policy: cuda + force_outputs_fp32: true + min_chunk_size: 64 + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname diff --git a/examples/images/diffusion/configs/train_colossalai_cifar10.yaml b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d0f185426c26d90e7ce186566c19960ff433ab6 --- /dev/null +++ b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml @@ -0,0 +1,112 @@ +model: + base_learning_rate: 1.0e-4 + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + scheduler_config: # 10000 warmup steps + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + + unet_config: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + + cond_stage_config: + freeze: True + layer: "penultimate" + +data: + batch_size: 4 + num_workers: 4 + train: + target: ldm.data.cifar10.hf_dataset + params: + name: cifar10 + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + - target: torchvision.transforms.RandomHorizontalFlip + +lightning: + trainer: + accelerator: 'gpu' + devices: 1 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + use_chunk: True + enable_distributed_storage: True + placement_policy: cuda + force_outputs_fp32: true + min_chunk_size: 64 + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname diff --git a/examples/images/diffusion/configs/train_ddp.yaml b/examples/images/diffusion/configs/train_ddp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f3ae3ddb5ff6d5847795f9d7a43afa243c0a85ed --- /dev/null +++ b/examples/images/diffusion/configs/train_ddp.yaml @@ -0,0 +1,102 @@ +model: + base_learning_rate: 1.0e-4 + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + scheduler_config: # 10000 warmup steps + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + + unet_config: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + + cond_stage_config: + freeze: True + layer: "penultimate" + +data: + batch_size: 128 + # num_workwers should be 2 * batch_size, and the total num less than 1024 + # e.g. if use 8 devices, no more than 128 + num_workers: 128 + train: + target: ldm.data.base.Txt2ImgIterableBaseDataset + params: + file_path: # YOUR DATAPATH + world_size: 1 + rank: 0 + +lightning: + trainer: + accelerator: 'gpu' + devices: 8 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + find_unused_parameters: False + log_every_n_steps: 2 +# max_steps: 6o + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + name: nowname + save_dir: "/data2/tmp/diff_log/" + offline: opt.debug + id: nowname diff --git a/examples/images/diffusion/docker/Dockerfile b/examples/images/diffusion/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..3b5301b898532572832ad79d4fc5eba7c3025dcf --- /dev/null +++ b/examples/images/diffusion/docker/Dockerfile @@ -0,0 +1,34 @@ +FROM hpcaitech/pytorch-cuda:1.12.0-11.3.0 + +# install torch +# RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch +RUN apt-get update +RUN apt-get install ffmpeg libsm6 libxext6 -y + +# install apex +RUN git clone https://github.com/NVIDIA/apex && \ + cd apex && \ + pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./ + +# install colossalai +# RUN git clone https://github.com/hpcaitech/ColossalAI.git \ +# && cd ./ColossalAI \ +# && pip install -v --no-cache-dir . + +RUN pip install colossalai + + +# install titans +RUN pip install --no-cache-dir titans + +RUN git clone https://github.com/hpcaitech/ColossalAI.git && \ + cd ./ColossalAI/examples/images/diffusion && \ + pip install -r requirements.txt && \ + pip install --no-cache-dir transformers==4.19.2 diffusers invisible-watermark + +# install tensornvme +# RUN conda install cmake && \ +# git clone https://github.com/hpcaitech/TensorNVMe.git && \ +# cd TensorNVMe && \ +# pip install -r requirements.txt && \ +# pip install -v --no-cache-dir . diff --git a/examples/images/diffusion/environment.yaml b/examples/images/diffusion/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d1ec69c1a5854cf97f8e8e25be11709ac47aee51 --- /dev/null +++ b/examples/images/diffusion/environment.yaml @@ -0,0 +1,32 @@ +name: ldm +channels: + - pytorch + - defaults +dependencies: + - python=3.9.12 + - pip=20.3 + - cudatoolkit=11.3 + - pytorch=1.12.1 + - torchvision=0.13.1 + - numpy=1.23.1 + - pip: + - albumentations==1.3.0 + - opencv-python==4.6.0.66 + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - omegaconf==2.1.1 + - test-tube>=0.7.5 + - streamlit==1.12.1 + - einops==0.3.0 + - transformers + - webdataset==0.2.5 + - kornia==0.6 + - open_clip_torch==2.0.2 + - invisible-watermark>=0.1.5 + - streamlit-drawable-canvas==0.8.0 + - torchmetrics==0.7.0 + - prefetch_generator + - datasets + - colossalai==0.2.5 + - lightning==1.9.0 + - -e . diff --git a/examples/images/diffusion/ldm/.DS_Store b/examples/images/diffusion/ldm/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..647199f9f93c6d786cda169822854ac38ca55144 Binary files /dev/null and b/examples/images/diffusion/ldm/.DS_Store differ diff --git a/examples/images/diffusion/ldm/data/__init__.py b/examples/images/diffusion/ldm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/ldm/data/base.py b/examples/images/diffusion/ldm/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a12492c95a162e53ba17903d9bfeb99ccd4a623f --- /dev/null +++ b/examples/images/diffusion/ldm/data/base.py @@ -0,0 +1,76 @@ +import math +import os +from abc import abstractmethod + +import cv2 +import numpy as np +import torch +from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset + + +class Txt2ImgIterableBaseDataset(IterableDataset): + ''' + Define an interface to make the IterableDatasets for text2img data chainable + ''' + + def __init__(self, file_path: str, rank, world_size): + super().__init__() + self.file_path = file_path + self.folder_list = [] + self.file_list = [] + self.txt_list = [] + self.info = self._get_file_info(file_path) + self.start = self.info['start'] + self.end = self.info['end'] + self.rank = rank + + self.world_size = world_size + # self.per_worker = int(math.floor((self.end - self.start) / float(self.world_size))) + # self.iter_start = self.start + self.rank * self.per_worker + # self.iter_end = min(self.iter_start + self.per_worker, self.end) + # self.num_records = self.iter_end - self.iter_start + # self.valid_ids = [i for i in range(self.iter_end)] + self.num_records = self.end - self.start + self.valid_ids = [i for i in range(self.end)] + + print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + + def __len__(self): + # return self.iter_end - self.iter_start + return self.end - self.start + + def __iter__(self): + sample_iterator = self._sample_generator(self.start, self.end) + # sample_iterator = self._sample_generator(self.iter_start, self.iter_end) + return sample_iterator + + def _sample_generator(self, start, end): + for idx in range(start, end): + file_name = self.file_list[idx] + txt_name = self.txt_list[idx] + f_ = open(txt_name, 'r') + txt_ = f_.read() + f_.close() + image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = torch.from_numpy(image) / 255 + yield {"txt": txt_, "image": image} + + def _get_file_info(self, file_path): + info = \ + { + "start": 1, + "end": 0, + } + self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i] + for folder in self.folder_list: + files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i] + txts = [k.replace('jpg', 'txt') for k in files] + self.file_list.extend(files) + self.txt_list.extend(txts) + info['end'] = len(self.file_list) + # with open(file_path, 'r') as fin: + # for _ in enumerate(fin): + # info['end'] += 1 + # self.txt_list = [k.replace('jpg', 'txt') for k in self.file_list] + return info diff --git a/examples/images/diffusion/ldm/data/cifar10.py b/examples/images/diffusion/ldm/data/cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..53cd61263b472d37c6b2b896cfd8ba2f89477b9a --- /dev/null +++ b/examples/images/diffusion/ldm/data/cifar10.py @@ -0,0 +1,184 @@ +from typing import Dict +import numpy as np +from omegaconf import DictConfig, ListConfig +import torch +from torch.utils.data import Dataset +from pathlib import Path +import json +from PIL import Image +from torchvision import transforms +from einops import rearrange +from ldm.util import instantiate_from_config +from datasets import load_dataset + +def make_multi_folder_data(paths, caption_files=None, **kwargs): + """Make a concat dataset from multiple folders + Don't suport captions yet + If paths is a list, that's ok, if it's a Dict interpret it as: + k=folder v=n_times to repeat that + """ + list_of_paths = [] + if isinstance(paths, (Dict, DictConfig)): + assert caption_files is None, \ + "Caption files not yet supported for repeats" + for folder_path, repeats in paths.items(): + list_of_paths.extend([folder_path]*repeats) + paths = list_of_paths + + if caption_files is not None: + datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)] + else: + datasets = [FolderData(p, **kwargs) for p in paths] + return torch.utils.data.ConcatDataset(datasets) + +class FolderData(Dataset): + def __init__(self, + root_dir, + caption_file=None, + image_transforms=[], + ext="jpg", + default_caption="", + postprocess=None, + return_paths=False, + ) -> None: + """Create a dataset from a folder of images. + If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = Path(root_dir) + self.default_caption = default_caption + self.return_paths = return_paths + if isinstance(postprocess, DictConfig): + postprocess = instantiate_from_config(postprocess) + self.postprocess = postprocess + if caption_file is not None: + with open(caption_file, "rt") as f: + ext = Path(caption_file).suffix.lower() + if ext == ".json": + captions = json.load(f) + elif ext == ".jsonl": + lines = f.readlines() + lines = [json.loads(x) for x in lines] + captions = {x["file_name"]: x["text"].strip("\n") for x in lines} + else: + raise ValueError(f"Unrecognised format: {ext}") + self.captions = captions + else: + self.captions = None + + if not isinstance(ext, (tuple, list, ListConfig)): + ext = [ext] + + # Only used if there is no caption file + self.paths = [] + for e in ext: + self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) + if isinstance(image_transforms, ListConfig): + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms = transforms.Compose(image_transforms) + self.tform = image_transforms + + + def __len__(self): + if self.captions is not None: + return len(self.captions.keys()) + else: + return len(self.paths) + + def __getitem__(self, index): + data = {} + if self.captions is not None: + chosen = list(self.captions.keys())[index] + caption = self.captions.get(chosen, None) + if caption is None: + caption = self.default_caption + filename = self.root_dir/chosen + else: + filename = self.paths[index] + + if self.return_paths: + data["path"] = str(filename) + + im = Image.open(filename) + im = self.process_im(im) + data["image"] = im + + if self.captions is not None: + data["txt"] = caption + else: + data["txt"] = self.default_caption + + if self.postprocess is not None: + data = self.postprocess(data) + + return data + + def process_im(self, im): + im = im.convert("RGB") + return self.tform(im) + +def hf_dataset( + name, + image_transforms=[], + image_column="img", + label_column="label", + text_column="txt", + split='train', + image_key='image', + caption_key='txt', + ): + """Make huggingface dataset with appropriate list of transforms applied + """ + ds = load_dataset(name, split=split) + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + tform = transforms.Compose(image_transforms) + + assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" + assert label_column in ds.column_names, f"Didn't find column {label_column} in {ds.column_names}" + + def pre_process(examples): + processed = {} + processed[image_key] = [tform(im) for im in examples[image_column]] + + label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"} + + processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]] + + return processed + + ds.set_transform(pre_process) + return ds + +class TextOnly(Dataset): + def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1): + """Returns only captions with dummy images""" + self.output_size = output_size + self.image_key = image_key + self.caption_key = caption_key + if isinstance(captions, Path): + self.captions = self._load_caption_file(captions) + else: + self.captions = captions + + if n_gpus > 1: + # hack to make sure that all the captions appear on each gpu + repeated = [n_gpus*[x] for x in self.captions] + self.captions = [] + [self.captions.extend(x) for x in repeated] + + def __len__(self): + return len(self.captions) + + def __getitem__(self, index): + dummy_im = torch.zeros(3, self.output_size, self.output_size) + dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c') + return {self.image_key: dummy_im, self.caption_key: self.captions[index]} + + def _load_caption_file(self, filename): + with open(filename, 'rt') as f: + captions = f.readlines() + return [x.strip('\n') for x in captions] \ No newline at end of file diff --git a/examples/images/diffusion/ldm/data/imagenet.py b/examples/images/diffusion/ldm/data/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..1c473f9c6965b22315dbb289eff8247c71bdc790 --- /dev/null +++ b/examples/images/diffusion/ldm/data/imagenet.py @@ -0,0 +1,394 @@ +import os, yaml, pickle, shutil, tarfile, glob +import cv2 +import albumentations +import PIL +import numpy as np +import torchvision.transforms.functional as TF +from omegaconf import OmegaConf +from functools import partial +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset, Subset + +import taming.data.utils as tdu +from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve +from taming.data.imagenet import ImagePaths + +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light + + +def synset2idx(path_to_yaml="data/index_synset.yaml"): + with open(path_to_yaml) as f: + di2s = yaml.load(f) + return dict((v,k) for k,v in di2s.items()) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config or OmegaConf.create() + if not type(self.config)==dict: + self.config = OmegaConf.to_container(self.config) + self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) + self.process_images = True # if False we skip loading & processing images and self.data contains filepaths + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._prepare_human_to_integer_label() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _prepare_human_to_integer_label(self): + URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" + self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") + if (not os.path.exists(self.human2integer)): + download(URL, self.human2integer) + with open(self.human2integer, "r") as f: + lines = f.read().splitlines() + assert len(lines) == 1000 + self.human2integer_dict = dict() + for line in lines: + value, key = line.split(":") + self.human2integer_dict[key] = int(value) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + if not self.keep_orig_class_label: + self.class_labels = [class_dict[s] for s in self.synsets] + else: + self.class_labels = [self.synset2idx[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + + if self.process_images: + self.size = retrieve(self.config, "size", default=256) + self.data = ImagePaths(self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) + else: + self.data = self.abspaths + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.process_images = process_images + self.data_root = data_root + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.data_root = data_root + self.process_images = process_images + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + + +class ImageNetSR(Dataset): + def __init__(self, size=None, + degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., + random_crop=True): + """ + Imagenet Superresolution Dataloader + Performs following ops in order: + 1. crops a crop of size s from image either as random or center crop + 2. resizes crop to size with cv2.area_interpolation + 3. degrades resized crop with degradation_fn + + :param size: resizing to size after cropping + :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light + :param downscale_f: Low Resolution Downsample factor + :param min_crop_f: determines crop size s, + where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) + :param max_crop_f: "" + :param data_root: + :param random_crop: + """ + self.base = self.get_base() + assert size + assert (size / downscale_f).is_integer() + self.size = size + self.LR_size = int(size / downscale_f) + self.min_crop_f = min_crop_f + self.max_crop_f = max_crop_f + assert(max_crop_f <= 1.) + self.center_crop = not random_crop + + self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) + + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + + if degradation == "bsrgan": + self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) + + elif degradation == "bsrgan_light": + self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) + + else: + interpolation_fn = { + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, + }[degradation] + + self.pil_interpolation = degradation.startswith("pil_") + + if self.pil_interpolation: + self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) + + else: + self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, + interpolation=interpolation_fn) + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = Image.open(example["file_path_"]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + image = np.array(image).astype(np.uint8) + + min_side_len = min(image.shape[:2]) + crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) + crop_side_len = int(crop_side_len) + + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) + + else: + self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) + + image = self.cropper(image=image)["image"] + image = self.image_rescaler(image=image)["image"] + + if self.pil_interpolation: + image_pil = PIL.Image.fromarray(image) + LR_image = self.degradation_process(image_pil) + LR_image = np.array(LR_image).astype(np.uint8) + + else: + LR_image = self.degradation_process(image=image)["image"] + + example["image"] = (image/127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + + return example + + +class ImageNetSRTrain(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_train_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetTrain(process_images=False,) + return Subset(dset, indices) + + +class ImageNetSRValidation(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_val_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetValidation(process_images=False,) + return Subset(dset, indices) diff --git a/examples/images/diffusion/ldm/data/lsun.py b/examples/images/diffusion/ldm/data/lsun.py new file mode 100644 index 0000000000000000000000000000000000000000..f5bf26c1425413f2d8ecb39913c2baccc42c5631 --- /dev/null +++ b/examples/images/diffusion/ldm/data/lsun.py @@ -0,0 +1,110 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +# This class is used to create a dataset of images from LSUN dataset for training +class LSUNBase(Dataset): + def __init__(self, + txt_file, # path to the text file containing the list of image paths + data_root, # root directory of the LSUN dataset + size=None, # the size of images to resize to + interpolation="bicubic", # interpolation method to be used while resizing + flip_p=0.5 # probability of random horizontal flipping + ): + self.data_paths = txt_file # store path to text file containing list of images + self.data_root = data_root # store path to root directory of the dataset + with open(self.data_paths, "r") as f: # open and read the text file + self.image_paths = f.read().splitlines() # read the lines of the file and store as list + self._length = len(self.image_paths) # store the number of images + + # create dictionary to hold image path information + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) + for l in self.image_paths], + } + + # set the image size to be resized + self.size = size + # set the interpolation method for resizing the image + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + # randomly flip the image horizontally with a given probability + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + # return the length of dataset + return self._length + + + def __getitem__(self, i): + # get the image path for the given index + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + # convert it to RGB format + if not image.mode == "RGB": + image = image.convert("RGB") + + # default to score-sde preprocessing + + img = np.array(image).astype(np.uint8) # convert image to numpy array + crop = min(img.shape[0], img.shape[1]) # crop the image to a square shape + h, w, = img.shape[0], img.shape[1] # get the height and width of image + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] # crop the image to a square shape + + image = Image.fromarray(img) # create an image from numpy array + if self.size is not None: # if image size is provided, resize the image + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) # flip the image horizontally with the given probability + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) # normalize the image values and convert to float32 + return example # return the example dictionary containing the image and its file paths + +#A dataset class for LSUN Churches training set. +# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. +# The text file containing the paths to the images and the root directory where the images are stored are passed as arguments. Any additional keyword arguments passed to this class will be forwarded to the constructor of the parent class. +class LSUNChurchesTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) + +#A dataset class for LSUN Churches validation set. +# It is similar to LSUNChurchesTrain except that it uses a different text file and sets the flip probability to zero by default. +class LSUNChurchesValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", + flip_p=flip_p, **kwargs) + +# A dataset class for LSUN Bedrooms training set. +# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. +class LSUNBedroomsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) + +# A dataset class for LSUN Bedrooms validation set. +# It is similar to LSUNBedroomsTrain except that it uses a different text file and sets the flip probability to zero by default. +class LSUNBedroomsValidation(LSUNBase): + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", + flip_p=flip_p, **kwargs) + +# A dataset class for LSUN Cats training set. +# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. +# The text file containing the paths to the images and the root directory where the images are stored are passed as arguments. +class LSUNCatsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) + +# A dataset class for LSUN Cats validation set. +# It is similar to LSUNCatsTrain except that it uses a different text file and sets the flip probability to zero by default. +class LSUNCatsValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", + flip_p=flip_p, **kwargs) diff --git a/examples/images/diffusion/ldm/data/teyvat.py b/examples/images/diffusion/ldm/data/teyvat.py new file mode 100644 index 0000000000000000000000000000000000000000..eb5d3ea469d4c68186f15008accbcc1a9ae8c8b7 --- /dev/null +++ b/examples/images/diffusion/ldm/data/teyvat.py @@ -0,0 +1,152 @@ +from typing import Dict +import numpy as np +from omegaconf import DictConfig, ListConfig +import torch +from torch.utils.data import Dataset +from pathlib import Path +import json +from PIL import Image +from torchvision import transforms +from einops import rearrange +from ldm.util import instantiate_from_config +from datasets import load_dataset + +def make_multi_folder_data(paths, caption_files=None, **kwargs): + """Make a concat dataset from multiple folders + Don't support captions yet + If paths is a list, that's ok, if it's a Dict interpret it as: + k=folder v=n_times to repeat that + """ + list_of_paths = [] + if isinstance(paths, (Dict, DictConfig)): + assert caption_files is None, \ + "Caption files not yet supported for repeats" + for folder_path, repeats in paths.items(): + list_of_paths.extend([folder_path]*repeats) + paths = list_of_paths + + if caption_files is not None: + datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)] + else: + datasets = [FolderData(p, **kwargs) for p in paths] + return torch.utils.data.ConcatDataset(datasets) + +class FolderData(Dataset): + def __init__(self, + root_dir, + caption_file=None, + image_transforms=[], + ext="jpg", + default_caption="", + postprocess=None, + return_paths=False, + ) -> None: + """Create a dataset from a folder of images. + If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = Path(root_dir) + self.default_caption = default_caption + self.return_paths = return_paths + if isinstance(postprocess, DictConfig): + postprocess = instantiate_from_config(postprocess) + self.postprocess = postprocess + if caption_file is not None: + with open(caption_file, "rt") as f: + ext = Path(caption_file).suffix.lower() + if ext == ".json": + captions = json.load(f) + elif ext == ".jsonl": + lines = f.readlines() + lines = [json.loads(x) for x in lines] + captions = {x["file_name"]: x["text"].strip("\n") for x in lines} + else: + raise ValueError(f"Unrecognised format: {ext}") + self.captions = captions + else: + self.captions = None + + if not isinstance(ext, (tuple, list, ListConfig)): + ext = [ext] + + # Only used if there is no caption file + self.paths = [] + for e in ext: + self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) + if isinstance(image_transforms, ListConfig): + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms = transforms.Compose(image_transforms) + self.tform = image_transforms + + + def __len__(self): + if self.captions is not None: + return len(self.captions.keys()) + else: + return len(self.paths) + + def __getitem__(self, index): + data = {} + if self.captions is not None: + chosen = list(self.captions.keys())[index] + caption = self.captions.get(chosen, None) + if caption is None: + caption = self.default_caption + filename = self.root_dir/chosen + else: + filename = self.paths[index] + + if self.return_paths: + data["path"] = str(filename) + + im = Image.open(filename) + im = self.process_im(im) + data["image"] = im + + if self.captions is not None: + data["txt"] = caption + else: + data["txt"] = self.default_caption + + if self.postprocess is not None: + data = self.postprocess(data) + + return data + + def process_im(self, im): + im = im.convert("RGB") + return self.tform(im) + +def hf_dataset( + path = "Fazzie/Teyvat", + image_transforms=[], + image_column="image", + text_column="text", + image_key='image', + caption_key='txt', + ): + """Make huggingface dataset with appropriate list of transforms applied + """ + ds = load_dataset(path, name="train") + ds = ds["train"] + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))] + ) + tform = transforms.Compose(image_transforms) + + assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" + assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}" + + def pre_process(examples): + processed = {} + processed[image_key] = [tform(im) for im in examples[image_column]] + processed[caption_key] = examples[text_column] + + return processed + + ds.set_transform(pre_process) + return ds \ No newline at end of file diff --git a/examples/images/diffusion/ldm/lr_scheduler.py b/examples/images/diffusion/ldm/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..be39da9ca6dacc22bf3df9c7389bbb403a4a3ade --- /dev/null +++ b/examples/images/diffusion/ldm/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/examples/images/diffusion/ldm/models/autoencoder.py b/examples/images/diffusion/ldm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f0a69fe63a8ce73f6b95540d0cdd9e0adfdbf170 --- /dev/null +++ b/examples/images/diffusion/ldm/models/autoencoder.py @@ -0,0 +1,220 @@ +import torch +import lightning.pytorch as pl + +from torch import nn +from torch.nn import functional as F +from torch.nn import Identity +from contextlib import contextmanager + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution +from ldm.modules.ema import LitEma + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False + ): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = Identity() + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + if self.use_ema: + self.ema_decay = ema_decay + assert 0. < ema_decay < 1. + self.model_ema = LitEma(self, decay=ema_decay) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, postfix=""): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( + self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) + if self.learn_logvar: + print(f"{self.__class__.__name__}: Learning logvar") + ae_params_list.append(self.loss.logvar) + opt_ae = torch.optim.Adam(ae_params_list, + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) + log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x + diff --git a/examples/images/diffusion/ldm/models/diffusion/__init__.py b/examples/images/diffusion/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/ldm/models/diffusion/classifier.py b/examples/images/diffusion/ldm/models/diffusion/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..3cf12f093beaa34d52119f4634a1922a35ee00eb --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/classifier.py @@ -0,0 +1,268 @@ +import os +import torch +import lightning.pytorch as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.lr_scheduler import LambdaLinearScheduler +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = LatentDiffusion(**self.diffusion_config.get('params',dict())) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = LambdaLinearScheduler(**self.scheduler_config.get('params',dict())) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/examples/images/diffusion/ldm/models/diffusion/ddim.py b/examples/images/diffusion/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..27ead0ea914c64c747b64e690662899fb3801144 --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/ddim.py @@ -0,0 +1,336 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, + ucg_schedule=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + if ucg_schedule is not None: + assert len(ucg_schedule) == len(time_range) + unconditional_guidance_scale = ucg_schedule[i] + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + model_output = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [torch.cat([ + unconditional_conditioning[k][i], + c[k][i]]) for i in range(len(c[k]))] + else: + c_in[k] = torch.cat([ + unconditional_conditioning[k], + c[k]]) + elif isinstance(c, list): + c_in = list() + assert isinstance(unconditional_conditioning, list) + for i in range(len(c)): + c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) + else: + c_in = torch.cat([unconditional_conditioning, c]) + model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", 'not implemented' + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise NotImplementedError() + + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, + unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc='Encoding Image'): + t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) + if unconditional_guidance_scale == 1.: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), + torch.cat((unconditional_conditioning, c))), 2) + noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = alphas_next[i].sqrt() * ( + (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + x_next = xt_weighted + weighted_noise_pred + if return_intermediates and i % ( + num_steps // return_intermediates) == 0 and i < num_steps - 1: + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + if callback: callback(i) + + out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + if return_intermediates: + out.update({'intermediates': intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False, callback=None): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + if callback: callback(i) + return x_dec \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/ddpm.py b/examples/images/diffusion/ldm/models/diffusion/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..842ec1371ea09fdb023f584e0525b815111ad2be --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/ddpm.py @@ -0,0 +1,2059 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import numpy as np +import torch +import torch.nn as nn + +try: + import lightning.pytorch as pl + from lightning.pytorch.utilities import rank_zero_info, rank_zero_only +except: + import pytorch_lightning as pl + from pytorch_lightning.utilities import rank_zero_only, rank_zero_info + +import itertools +from contextlib import contextmanager, nullcontext +from functools import partial + +from einops import rearrange, repeat +from ldm.lr_scheduler import LambdaLinearScheduler +from ldm.models.autoencoder import * +from ldm.models.autoencoder import AutoencoderKL, IdentityFirstStage +from ldm.models.diffusion.ddim import * +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.modules.midas.api import MiDaSInference +from ldm.modules.diffusionmodules.model import * +from ldm.modules.diffusionmodules.model import Decoder, Encoder, Model +from ldm.modules.diffusionmodules.openaimodel import * +from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d, UNetModel +from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule, noise_like +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl +from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation +from ldm.modules.ema import LitEma +from ldm.modules.encoders.modules import * +from ldm.util import count_params, default, exists, isimage, ismap, log_txt_as_img, mean_flat +from omegaconf import ListConfig +from torch.optim.lr_scheduler import LambdaLR +from torchvision.utils import make_grid +from tqdm import tqdm + +__conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__( + self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + use_fp16=True, + make_it_fit=False, + ucg_training=None, + reset_ema=False, + reset_num_ema_updates=False, + ): + super().__init__() + assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' + self.parameterization = parameterization + rank_zero_info(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size + self.channels = channels + self.use_positional_encodings = use_positional_encodings + + self.unet_config = unet_config + self.conditioning_key = conditioning_key + self.model = DiffusionWrapper(unet_config, conditioning_key) + # count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + rank_zero_info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + self.make_it_fit = make_it_fit + self.ckpt = ckpt + self.ignore_keys = ignore_keys + self.load_only_unet = load_only_unet + self.reset_ema = reset_ema + self.reset_num_ema_updates = reset_num_ema_updates + + if reset_ema: + assert exists(ckpt) + ''' + Uncomment if you Use DDP Strategy + ''' + # if ckpt is not None: + # self.init_from_ckpt(ckpt, ignore_keys=ignore_keys, only_model=load_only_unet) + # if reset_ema: + # assert self.use_ema + # rank_zero_info(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + # self.model_ema = LitEma(self.model) + + if reset_num_ema_updates: + rank_zero_info(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + self.timesteps = timesteps + self.beta_schedule = beta_schedule + self.given_betas = given_betas + self.linear_start = linear_start + self.linear_end = linear_end + self.cosine_s = cosine_s + + self.register_schedule(given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s) + + self.loss_type = loss_type + + self.logvar_init = logvar_init + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + self.use_fp16 = use_fp16 + self.ucg_training = ucg_training or dict() + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + + def register_schedule(self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', + to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like(self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * + (1 - self.alphas_cumprod))) + else: + raise NotImplementedError("mu not supported") + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + rank_zero_info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + rank_zero_info(f"{context}: Restored training weights") + + @torch.no_grad() + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + rank_zero_info("Deleting key {} from state_dict.".format(k)) + del sd[k] + if self.make_it_fit: + n_params = len([name for name, _ in itertools.chain(self.named_parameters(), self.named_buffers())]) + for name, param in tqdm(itertools.chain(self.named_parameters(), self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params): + if not name in sd: + continue + old_shape = sd[name].shape + new_shape = param.shape + assert len(old_shape) == len(new_shape) + if len(new_shape) > 2: + # we only modify first two axes + assert new_shape[2:] == old_shape[2:] + # assumes first axis corresponds to output dim + if not new_shape == old_shape: + new_param = param.clone() + old_param = sd[name] + if len(new_shape) == 1: + for i in range(new_param.shape[0]): + new_param[i] = old_param[i % old_shape[0]] + elif len(new_shape) >= 2: + for i in range(new_param.shape[0]): + for j in range(new_param.shape[1]): + new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]] + + n_used_old = torch.ones(old_shape[1]) + for j in range(new_param.shape[1]): + n_used_old[j % old_shape[1]] += 1 + n_used_new = torch.zeros(new_shape[1]) + for j in range(new_param.shape[1]): + n_used_new[j] = n_used_old[j % old_shape[1]] + + n_used_new = n_used_new[None, :] + while len(n_used_new.shape) < len(new_shape): + n_used_new = n_used_new.unsqueeze(-1) + new_param /= n_used_new + + sd[name] = new_param + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + rank_zero_info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + rank_zero_info(f"Missing Keys:\n {missing}") + if len(unexpected) > 0: + rank_zero_info(f"\nUnexpected Keys:\n {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise) + + def predict_start_from_z_and_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = (extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, + torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_v(self, x, noise, t): + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + if self.use_fp16: + x = x.to(memory_format=torch.contiguous_format).half() + else: + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + for k in self.ucg_training: + p = self.ucg_training[k]["p"] + val = self.ucg_training[k]["val"] + if val is None: + val = "" + for i in range(len(batch[k])): + if self.ucg_prng.choice(2, p=[1 - p, p]): + batch[k][i] = val + + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + use_fp16=True, + force_null_conditioning=False, + *args, + **kwargs): + self.force_null_conditioning = force_null_conditioning + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning: + conditioning_key = None + + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.first_stage_config = first_stage_config + self.cond_stage_config = cond_stage_config + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + ''' + Uncomment if you Use DDP Strategy + ''' + # self.restarted_from_ckpt = False + # if self.ckpt is not None: + # self.init_from_ckpt(self.ckpt, self.ignore_keys) + # self.restarted_from_ckpt = True + # if self.reset_ema: + # assert self.use_ema + # rank_zero_info( + # f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + # self.model_ema = LitEma(self.model) + if self.reset_num_ema_updates: + rank_zero_info(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + def configure_sharded_model(self) -> None: + rank_zero_info("Configure sharded model for LatentDiffusion") + self.model = DiffusionWrapper(self.unet_config, self.conditioning_key) + count_params(self.model, verbose=True) + if self.use_ema: + self.model_ema = LitEma(self.model) + + if self.ckpt is not None: + self.init_from_ckpt(self.ckpt, ignore_keys=self.ignore_keys, only_model=self.load_only_unet) + if self.reset_ema: + assert self.use_ema + rank_zero_info( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + + self.register_schedule(given_betas=self.given_betas, + beta_schedule=self.beta_schedule, + timesteps=self.timesteps, + linear_start=self.linear_start, + linear_end=self.linear_end, + cosine_s=self.cosine_s) + + self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + + self.instantiate_first_stage(self.first_stage_config) + self.instantiate_cond_stage(self.cond_stage_config) + if self.ckpt is not None: + self.init_from_ckpt(self.ckpt, self.ignore_keys) + self.restarted_from_ckpt = True + if self.reset_ema: + assert self.use_ema + rank_zero_info( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + + def make_cond_schedule(self,): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + rank_zero_info("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + rank_zero_info(f"setting self.scale_factor to {self.scale_factor}") + rank_zero_info("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = AutoencoderKL(**config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + rank_zero_info("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + rank_zero_info(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = FrozenOpenCLIPEmbedder(**config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + model = FrozenOpenCLIPEmbedder(**config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append( + self.decode_first_stage(zd.to(self.device), force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z.half() if self.use_fp16 else self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip( + weighting, + self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], + ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, + padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, + padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, + batch, + k, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + return_x=False): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None and not self.force_null_conditioning: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox', "txt"]: + xc = batch[cond_key] + elif cond_key in ['class_label', 'cls']: + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_x: + out.extend([x]) + if return_original_cond: + out.append(xc) + + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def apply_model(self, x_noisy, t, cond, return_ids=False): + if isinstance(cond, dict): + # hybrid case, cond is expected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] if not isinstance(cond[key], list) else list( + map(lambda x: x[:batch_size], cond[key])) for key in cond + } + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed(range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] if not isinstance(cond[key], list) else list( + map(lambda x: x[:batch_size], cond[key])) for key in cond + } + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0) + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs) + + return samples, intermediates + + @torch.no_grad() + def get_unconditional_conditioning(self, batch_size, null_label=None): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + if self.cond_stage_key in ["class_label", "cls"]: + xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device) + return self.get_learned_conditioning(xc) + else: + raise NotImplementedError("todo") + if isinstance(c, list): # in case the encoder gives us a list + for i in range(len(c)): + c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device) + else: + c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device) + return c + + @torch.no_grad() + def log_images(self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=50, + ddim_eta=0., + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1., + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', "cls"]: + try: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + except KeyError: + # probably no "human_label" in batch + pass + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc = self.get_unconditional_conditioning(N, unconditional_guidance_label) + if self.model.conditioning_key == "crossattn-adm": + uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with ema_scope("Plotting Inpaint"): + samples, _ = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + eta=ddim_eta, + ddim_steps=ddim_steps, + x0=z[:N], + mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + mask = 1. - mask + with ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + eta=ddim_eta, + ddim_steps=ddim_steps, + x0=z[:N], + mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + rank_zero_info(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + rank_zero_info('Diffusion model optimizing logvar') + params.append(self.logvar) + + from colossalai.nn.optimizer import HybridAdam + opt = HybridAdam(params, lr=lr) + + # opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + scheduler = LambdaLinearScheduler(**self.scheduler_config) + + rank_zero_info("Setting up LambdaLR scheduler...") + scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) + self.diffusion_model = UNetModel(**diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + if not self.sequential_cross_attn: + cc = torch.cat(c_crossattn, 1) + else: + cc = c_crossattn + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'hybrid-adm': + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm) + elif self.conditioning_key == 'crossattn-adm': + assert c_adm is not None + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc, y=c_adm) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class LatentUpscaleDiffusion(LatentDiffusion): + + def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + self.noise_level_key = noise_level_key + + def instantiate_low_stage(self, config): + model = ImageConcatWithNoiseAugmentation(**config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input(batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs) + x_low = batch[self.low_scale_key][:bs] + x_low = rearrange(x_low, 'b h w c -> b c h w') + if self.use_fp16: + x_low = x_low.to(memory_format=torch.contiguous_format).half() + else: + x_low = x_low.to(memory_format=torch.contiguous_format).float() + zx, noise_level = self.low_scale_model(x_low) + if self.noise_level_key is not None: + # get noise level from batch instead, e.g. when extracting a custom noise level for bsr + raise NotImplementedError('TODO') + + all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} + if log_mode: + # TODO: maybe disable if too expensive + x_low_rec = self.low_scale_model.decode(zx) + return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level + return z, all_conds + + @torch.no_grad() + def log_images(self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1., + return_keys=None, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1., + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, + self.first_stage_key, + bs=N, + log_mode=True) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', 'cls']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + # TODO explore better "unconditional" choices for the other keys + # maybe guide away from empty text label and highest noise level and maximally degraded zx? + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif k == "c_adm": # todo: only run with text-based guidance? + assert isinstance(c[k], torch.Tensor) + #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + uc[k] = c[k] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + return log + + +class LatentFinetuneDiffusion(LatentDiffusion): + """ + Basis for different finetunas, such as inpainting or depth2image + To disable finetuning mode, set finetune_keys to None + """ + + def __init__( + self, + concat_keys: tuple, + finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight"), + keep_finetune_dims=4, + # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, + **kwargs): + ckpt = kwargs.pop("ckpt", None) + ignore_keys = kwargs.pop("ignore_keys", list()) + super().__init__(*args, **kwargs) + self.finetune_keys = finetune_keys + self.concat_keys = concat_keys + self.keep_dims = keep_finetune_dims + self.c_concat_log_start = c_concat_log_start + self.c_concat_log_end = c_concat_log_end + if exists(self.finetune_keys): + assert exists(ckpt), 'can only finetune from a given checkpoint' + if exists(ckpt): + self.init_from_ckpt(ckpt, ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + rank_zero_info("Deleting key {} from state_dict.".format(k)) + del sd[k] + + # make it explicit, finetune by including extra input channels + if exists(self.finetune_keys) and k in self.finetune_keys: + new_entry = None + for name, param in self.named_parameters(): + if name in self.finetune_keys: + rank_zero_info( + f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only" + ) + new_entry = torch.zeros_like(param) # zero init + assert exists(new_entry), 'did not find matching parameter to modify' + new_entry[:, :self.keep_dims, ...] = sd[k] + sd[k] = new_entry + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + rank_zero_info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + rank_zero_info(f"Missing Keys: {missing}") + if len(unexpected) > 0: + rank_zero_info(f"Unexpected Keys: {unexpected}") + + @torch.no_grad() + def log_images(self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1., + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1., + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True) + c_cat, c = c["c_concat"][0], c["c_crossattn"][0] + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', 'cls']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if not (self.c_concat_log_start is None and self.c_concat_log_end is None): + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end]) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond={ + "c_concat": [c_cat], + "c_crossattn": [c] + }, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label) + uc_cat = c_cat + uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log( + cond={ + "c_concat": [c_cat], + "c_crossattn": [c] + }, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + return log + + +class LatentInpaintDiffusion(LatentFinetuneDiffusion): + """ + can either run as pure inpainting model (only concat mode) or with mixed conditionings, + e.g. mask as concat and text via cross-attn. + To disable finetuning mode, set finetune_keys to None + """ + + def __init__(self, concat_keys=("mask", "masked_image"), masked_image_key="masked_image", *args, **kwargs): + super().__init__(concat_keys, *args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' + z, c, x, xrec, xc = super().get_input(batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs) + + assert exists(self.concat_keys) + c_cat = list() + for ck in self.concat_keys: + if self.use_fp16: + cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).half() + else: + cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + bchw = z.shape + if ck != self.masked_image_key: + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) + log["masked_image"] = rearrange(args[0]["masked_image"], + 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + return log + + +class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): + """ + condition on monocular depth estimation + """ + + def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.depth_model = MiDaSInference(**depth_stage_config) + self.depth_stage_key = concat_keys[0] + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img' + z, c, x, xrec, xc = super().get_input(batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + c_cat = list() + for ck in self.concat_keys: + cc = batch[ck] + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + cc = self.depth_model(cc) + cc = torch.nn.functional.interpolate( + cc, + size=z.shape[2:], + mode="bicubic", + align_corners=False, + ) + + depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, + dim=[1, 2, 3], + keepdim=True) + cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1. + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + depth = self.depth_model(args[0][self.depth_stage_key]) + depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \ + torch.amax(depth, dim=[1, 2, 3], keepdim=True) + log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1. + return log + + +class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): + """ + condition on low-res image (and optionally on some spatial noise augmentation) + """ + + def __init__(self, + concat_keys=("lr",), + reshuffle_patch_size=None, + low_scale_config=None, + low_scale_key=None, + *args, + **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.reshuffle_patch_size = reshuffle_patch_size + self.low_scale_model = None + if low_scale_config is not None: + rank_zero_info("Initializing a low-scale model") + assert exists(low_scale_key) + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + + def instantiate_low_stage(self, config): + model = ImageConcatWithNoiseAugmentation(**config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft' + z, c, x, xrec, xc = super().get_input(batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + # optionally make spatial noise_level here + c_cat = list() + noise_level = None + for ck in self.concat_keys: + cc = batch[ck] + cc = rearrange(cc, 'b h w c -> b c h w') + if exists(self.reshuffle_patch_size): + assert isinstance(self.reshuffle_patch_size, int) + cc = rearrange(cc, + 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', + p1=self.reshuffle_patch_size, + p2=self.reshuffle_patch_size) + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + if exists(self.low_scale_model) and ck == self.low_scale_key: + cc, noise_level = self.low_scale_model(cc) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + if exists(noise_level): + all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level} + else: + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') + return log diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08 --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py @@ -0,0 +1 @@ +from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..095e5ba3ce0b1aa7f4b3f1e2e5d8fff7cfe6dc8c --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -0,0 +1,1154 @@ +import torch +import torch.nn.functional as F +import math +from tqdm import tqdm + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + t = self.inverse_lambda(lambda_t) + =============================================================== + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + 1. For discrete-time DPMs: + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + 2. For continuous-time DPMs: + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + =============================================================== + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + Example: + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError( + "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), + self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0 ** 2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + We support four types of the diffusion model by setting `model_type`: + 1. "noise": noise prediction model. (Trained by predicting noise). + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + =============================================================== + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + """Construct a DPM-Solver. + We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). + If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). + If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). + In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. + The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. + thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. + max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = model_fn + self.noise_schedule = noise_schedule + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3, ] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3, ] * (K - 1) + [1] + else: + orders = [3, ] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2, ] * K + else: + K = steps // 2 + 1 + orders = [2, ] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1, ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum(torch.tensor([0, ] + orders)).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, + solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( + s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( + model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, + return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( + s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( + s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( + t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + ) + else: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( + t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, + r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, + solver_type='dpm_solver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((x.shape[0],)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + solver_type=solver_type, + **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + return_intermediate=True, + solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, + solver_type=solver_type, + **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + ===================================================== + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + ===================================================== + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'adaptive': + with torch.no_grad(): + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, + solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in tqdm(range(1, order), desc="DPM init order"): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, + solver_type=solver_type) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in tqdm(range(order, steps + 1), desc="DPM multistep"): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final and steps < 15: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, + solver_type=solver_type) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, vec_t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, + skip_type=skip_type, + t_T=t_T, t_0=t_0, + device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order, ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for i, order in enumerate(orders): + t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), + N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..7d137b8cf36718c1c58faa09f9dd919e5fb2977b --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py @@ -0,0 +1,87 @@ +"""SAMPLING ONLY.""" +import torch + +from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver + + +MODEL_TYPES = { + "eps": "noise", + "v": "v" +} + + +class DPMSolverSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type=MODEL_TYPES[self.model.parameterization], + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + + return x.to(device), None \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/plms.py b/examples/images/diffusion/ldm/models/diffusion/plms.py new file mode 100644 index 0000000000000000000000000000000000000000..7002a365d27168ced0a04e9a4d83e088f8284eae --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/plms.py @@ -0,0 +1,244 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from ldm.models.diffusion.sampling_util import norm_thresholding + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next, + dynamic_threshold=dynamic_threshold) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/examples/images/diffusion/ldm/models/diffusion/sampling_util.py b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py new file mode 100644 index 0000000000000000000000000000000000000000..7eff02be6d7c54d43ee6680636ac0698dd3b3f33 --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py @@ -0,0 +1,22 @@ +import torch +import numpy as np + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions. + From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def norm_thresholding(x0, value): + s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) + return x0 * (value / s) + + +def spatial_norm_thresholding(x0, value): + # b c h w + s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) + return x0 * (value / s) \ No newline at end of file diff --git a/examples/images/diffusion/ldm/modules/attention.py b/examples/images/diffusion/ldm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..d504d939f6a02cf45f028799d7d73b84500cee06 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/attention.py @@ -0,0 +1,331 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +from typing import Optional, Any + +from ldm.modules.diffusionmodules.util import checkpoint + + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads.") + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention + } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/__init__.py b/examples/images/diffusion/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..fb088db58919dd3ab79b2d6c7ab4d0e6a40f7454 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,860 @@ +# pytorch_diffusion + derived encoder decoder +import math +from typing import Any, Optional + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange + +try: + from lightning.pytorch.utilities import rank_zero_info +except: + from pytorch_lightning.utilities import rank_zero_info + +from ldm.modules.attention import MemoryEfficientCrossAttention + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + print("No module 'xformers'. Proceeding without it.") + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.attention_op: Optional[Any] = None + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3).reshape(B, t.shape[1], 1, C).permute(0, 2, 1, 3).reshape(B * 1, t.shape[1], C). + contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = (out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C)) + out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x + out + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + + def forward(self, x, context=None, mask=None): + b, c, h, w = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", + "none"], f'attn_type {attn_type} unknown' + if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + attn_type = "vanilla-xformers" + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + rank_zero_info(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + raise NotImplementedError() + + +class Model(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla"): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock(in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2**(self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + rank_zero_info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True) + ]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2**(self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1) + self.res_block1 = nn.ModuleList([ + ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) + for _ in range(depth) + ]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ + ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) + for _ in range(depth) + ]) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, + size=(int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + + def __init__(self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + + def __init__(self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1. + (out_size % in_size) + rank_zero_info( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler(factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + rank_zero_info( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..cd639d9360466c72b92db403e52514a149997ed8 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,787 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer +from ldm.util import exists + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + t_emb = t_emb.type(self.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py new file mode 100644 index 0000000000000000000000000000000000000000..03816662098ce1ffac79bd939b892e867ab91988 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import numpy as np +from functools import partial + +from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule +from ldm.util import default + + +class AbstractLowScaleModel(nn.Module): + # for concatenating a downsampled image to the latent representation + def __init__(self, noise_schedule_config=None): + super(AbstractLowScaleModel, self).__init__() + if noise_schedule_config is not None: + self.register_schedule(**noise_schedule_config) + + def register_schedule(self, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x): + return x, None + + def decode(self, x): + return x + + +class SimpleImageConcat(AbstractLowScaleModel): + # no noise level conditioning + def __init__(self): + super(SimpleImageConcat, self).__init__(noise_schedule_config=None) + self.max_noise_level = 0 + + def forward(self, x): + # fix to constant noise level + return x, torch.zeros(x.shape[0], device=x.device).long() + + +class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): + def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): + super().__init__(noise_schedule_config=noise_schedule_config) + self.max_noise_level = max_noise_level + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + z = self.q_sample(x, noise_level) + return z, noise_level + + + diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/util.py b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..36b4a171b6c2382206c4754a26889fd685276db1 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,273 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + +import math +import os + +import numpy as np +import torch +import torch.nn as nn +from einops import repeat +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = (torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2) + + elif schedule == "cosine": + timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + from torch.utils.checkpoint import checkpoint as torch_checkpoint + return torch_checkpoint(func, *inputs) + # args = tuple(inputs) + tuple(params) + # return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled() + } + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / + half).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return nn.GroupNorm(16, channels) + # return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/examples/images/diffusion/ldm/modules/distributions/__init__.py b/examples/images/diffusion/ldm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/ldm/modules/distributions/distributions.py b/examples/images/diffusion/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/examples/images/diffusion/ldm/modules/ema.py b/examples/images/diffusion/ldm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..bded25019b9bcbcd0260f0b8185f8c7859ca58c4 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/ema.py @@ -0,0 +1,80 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates + else torch.tensor(-1, dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace('.', '') + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/examples/images/diffusion/ldm/modules/encoders/__init__.py b/examples/images/diffusion/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/ldm/modules/encoders/modules.py b/examples/images/diffusion/ldm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..4edd5496b9e668ea72a5be39db9cca94b6a42f9b --- /dev/null +++ b/examples/images/diffusion/ldm/modules/encoders/modules.py @@ -0,0 +1,213 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel + +import open_clip +from ldm.util import default, count_params + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0. and not disable_dropout: + mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) + c = c.long() + c = self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = [ + "last", + "pooled", + "hidden" + ] + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, + freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="last"): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", + clip_max_length=77, t5_max_length=77): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] + + diff --git a/examples/images/diffusion/ldm/modules/image_degradation/__init__.py b/examples/images/diffusion/ldm/modules/image_degradation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc --- /dev/null +++ b/examples/images/diffusion/ldm/modules/image_degradation/__init__.py @@ -0,0 +1,2 @@ +from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr +from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py new file mode 100644 index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b --- /dev/null +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py @@ -0,0 +1,730 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image":image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') + + diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py new file mode 100644 index 0000000000000000000000000000000000000000..808c7f882cb75e2ba2340d5b55881d11927351f0 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,651 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2/4 + wd = wd/4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + # elif i == 1: + # image = add_blur(image, sf=sf) + + if i == 0: + pass + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + if up: + image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then + example = {"image": image} + return example + + + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/examples/images/diffusion/ldm/modules/image_degradation/utils/test.png b/examples/images/diffusion/ldm/modules/image_degradation/utils/test.png new file mode 100644 index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6 Binary files /dev/null and b/examples/images/diffusion/ldm/modules/image_degradation/utils/test.png differ diff --git a/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py b/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py new file mode 100644 index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py @@ -0,0 +1,916 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) +# print(w1) +# print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/examples/images/diffusion/ldm/modules/midas/__init__.py b/examples/images/diffusion/ldm/modules/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/ldm/modules/midas/api.py b/examples/images/diffusion/ldm/modules/midas/api.py new file mode 100644 index 0000000000000000000000000000000000000000..b58ebbffd942a2fc22264f0ab47e400c26b9f41c --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/api.py @@ -0,0 +1,170 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from ldm.modules.midas.midas.dpt_depth import DPTDepthModel +from ldm.modules.midas.midas.midas_net import MidasNet +from ldm.modules.midas.midas.midas_net_custom import MidasNet_small +from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet + + +ISL_PATHS = { + "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", + "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", + "midas_v21": "", + "midas_v21_small": "", +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array + # NOTE: we expect that the correct transform has been called during dataloading. + with torch.no_grad(): + prediction = self.model(x) + prediction = torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=x.shape[2:], + mode="bicubic", + align_corners=False, + ) + assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) + return prediction + diff --git a/examples/images/diffusion/ldm/modules/midas/midas/__init__.py b/examples/images/diffusion/ldm/modules/midas/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/ldm/modules/midas/midas/base_model.py b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/examples/images/diffusion/ldm/modules/midas/midas/blocks.py b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/examples/images/diffusion/ldm/modules/midas/midas/transforms.py b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/examples/images/diffusion/ldm/modules/midas/midas/vit.py b/examples/images/diffusion/ldm/modules/midas/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/examples/images/diffusion/ldm/modules/midas/utils.py b/examples/images/diffusion/ldm/modules/midas/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/examples/images/diffusion/ldm/util.py b/examples/images/diffusion/ldm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..8c09ca1c72f7ceb3f9d7f9546aae5561baf62b13 --- /dev/null +++ b/examples/images/diffusion/ldm/util.py @@ -0,0 +1,197 @@ +import importlib + +import torch +from torch import optim +import numpy as np + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x,torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +class AdamWwithEMAandWings(optim.Optimizer): + # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 + def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using + weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code + ema_power=1., param_names=()): + """AdamW that saves EMA versions of the parameters.""" + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= ema_decay <= 1.0: + raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, + ema_power=ema_power, param_names=param_names) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + ema_params_with_grad = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + ema_decay = group['ema_decay'] + ema_power = group['ema_power'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of parameter values + state['param_exp_avg'] = p.detach().float().clone() + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + ema_params_with_grad.append(state['param_exp_avg']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + optim._functional.adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=False) + + cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + for param, ema_param in zip(params_with_grad, ema_params_with_grad): + ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) + + return loss \ No newline at end of file diff --git a/examples/images/diffusion/main.py b/examples/images/diffusion/main.py new file mode 100644 index 0000000000000000000000000000000000000000..713029fc677d818b9d63e3d2c6f15a5592d2d3aa --- /dev/null +++ b/examples/images/diffusion/main.py @@ -0,0 +1,870 @@ +import argparse +import csv +import datetime +import glob +import importlib +import os +import sys +import time + +import numpy as np +import torch +import torchvision +import lightning.pytorch as pl + + +from functools import partial + +from omegaconf import OmegaConf +from packaging import version +from PIL import Image +from prefetch_generator import BackgroundGenerator +from torch.utils.data import DataLoader, Dataset, Subset, random_split +from ldm.models.diffusion.ddpm import LatentDiffusion + +from lightning.pytorch import seed_everything +from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.trainer import Trainer +from lightning.pytorch.utilities import rank_zero_info, rank_zero_only +from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger +from lightning.pytorch.strategies import ColossalAIStrategy,DDPStrategy +LIGHTNING_PACK_NAME = "lightning.pytorch." + +from ldm.data.base import Txt2ImgIterableBaseDataset +from ldm.util import instantiate_from_config + +# from ldm.modules.attention import enable_flash_attentions + + +class DataLoaderX(DataLoader): +# A custom data loader class that inherits from DataLoader + def __iter__(self): + # Overriding the __iter__ method of DataLoader to return a BackgroundGenerator + #This is to enable data loading in the background to improve training performance + return BackgroundGenerator(super().__iter__()) + + +def get_parser(**parser_kwargs): + #A function to create an ArgumentParser object and add arguments to it + + def str2bool(v): + # A helper function to parse boolean values from command line arguments + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + # Create an ArgumentParser object with specifies kwargs + parser = argparse.ArgumentParser(**parser_kwargs) + + # Add various command line arguments with their default values and descriptions + parser.add_argument( + "-n", + "--name", + type=str, + const=True, + default="", + nargs="?", + help="postfix for logdir", + ) + parser.add_argument( + "-r", + "--resume", + type=str, + const=True, + default="", + nargs="?", + help="resume from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-b", + "--base", + nargs="*", + metavar="base_config.yaml", + help="paths to base configs. Loaded from left-to-right. " + "Parameters can be overwritten or added with command-line options of the form `--key value`.", + default=list(), + ) + parser.add_argument( + "-t", + "--train", + type=str2bool, + const=True, + default=False, + nargs="?", + help="train", + ) + parser.add_argument( + "--no-test", + type=str2bool, + const=True, + default=False, + nargs="?", + help="disable test", + ) + parser.add_argument( + "-p", + "--project", + help="name of new or path to existing project", + ) + parser.add_argument( + "-c", + "--ckpt", + type=str, + const=True, + default="", + nargs="?", + help="load pretrained checkpoint from stable AI", + ) + parser.add_argument( + "-d", + "--debug", + type=str2bool, + nargs="?", + const=True, + default=False, + help="enable post-mortem debugging", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + default=23, + help="seed for seed_everything", + ) + parser.add_argument( + "-f", + "--postfix", + type=str, + default="", + help="post-postfix for default name", + ) + parser.add_argument( + "-l", + "--logdir", + type=str, + default="logs", + help="directory for logging dat shit", + ) + parser.add_argument( + "--scale_lr", + type=str2bool, + nargs="?", + const=True, + default=True, + help="scale base-lr by ngpu * batch_size * n_accumulate", + ) + + return parser + +# A function that returns the non-default arguments between two objects +def nondefault_trainer_args(opt): + # create an argument parser + parser = argparse.ArgumentParser() + # add pytorch lightning trainer default arguments + parser = Trainer.add_argparse_args(parser) + # parse the empty arguments to obtain the default values + args = parser.parse_args([]) + # return all non-default arguments + return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) + +# A dataset wrapper class to create a pytorch dataset from an arbitrary object +class WrappedDataset(Dataset): + """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" + + def __init__(self, dataset): + self.data = dataset + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + +# A function to initialize worker processes +def worker_init_fn(_): + worker_info = torch.utils.data.get_worker_info() + + dataset = worker_info.dataset + worker_id = worker_info.id + + if isinstance(dataset, Txt2ImgIterableBaseDataset): + #divide the dataset into equal parts for each worker + split_size = dataset.num_records // worker_info.num_workers + #set the sample IDs for the current worker + # reset num_records to the true number to retain reliable length information + dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] + # set the seed for the current worker + current_id = np.random.choice(len(np.random.get_state()[1]), 1) + return np.random.seed(np.random.get_state()[1][current_id] + worker_id) + else: + return np.random.seed(np.random.get_state()[1][0] + worker_id) + +#Provide functionality for creating data loaders based on provided dataset configurations +class DataModuleFromConfig(pl.LightningDataModule): + + def __init__(self, + batch_size, + train=None, + validation=None, + test=None, + predict=None, + wrap=False, + num_workers=None, + shuffle_test_loader=False, + use_worker_init_fn=False, + shuffle_val_dataloader=False): + super().__init__() + # Set data module attributes + self.batch_size = batch_size + self.dataset_configs = dict() + self.num_workers = num_workers if num_workers is not None else batch_size * 2 + self.use_worker_init_fn = use_worker_init_fn + # If a dataset is passed, add it to the dataset configs and create a corresponding dataloader method + if train is not None: + self.dataset_configs["train"] = train + self.train_dataloader = self._train_dataloader + if validation is not None: + self.dataset_configs["validation"] = validation + self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) + if test is not None: + self.dataset_configs["test"] = test + self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) + if predict is not None: + self.dataset_configs["predict"] = predict + self.predict_dataloader = self._predict_dataloader + self.wrap = wrap + + def prepare_data(self): + # Instantiate datasets + for data_cfg in self.dataset_configs.values(): + instantiate_from_config(data_cfg) + + def setup(self, stage=None): + # Instantiate datasets from the dataset configs + self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) + + # If wrap is true, create a WrappedDataset for each dataset + if self.wrap: + for k in self.datasets: + self.datasets[k] = WrappedDataset(self.datasets[k]) + + def _train_dataloader(self): + #Check if the train dataset is iterable + is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + #Set the worker initialization function of the dataset is iterable or use_worker_init_fn is True + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + # Return a DataLoaderX object for the train dataset + return DataLoaderX(self.datasets["train"], + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False if is_iterable_dataset else True, + worker_init_fn=init_fn) + + def _val_dataloader(self, shuffle=False): + #Check if the validation dataset is iterable + if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + # Return a DataLoaderX object for the validation dataset + return DataLoaderX(self.datasets["validation"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle) + + def _test_dataloader(self, shuffle=False): + # Check if the test dataset is iterable + is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + # Set the worker initialization function if the dataset is iterable or use_worker_init_fn is True + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + + # do not shuffle dataloader for iterable dataset + shuffle = shuffle and (not is_iterable_dataset) + + return DataLoaderX(self.datasets["test"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle) + + def _predict_dataloader(self, shuffle=False): + if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoaderX(self.datasets["predict"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn) + + +class SetupCallback(Callback): + # Initialize the callback with the necessary parameters + + def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): + super().__init__() + self.resume = resume + self.now = now + self.logdir = logdir + self.ckptdir = ckptdir + self.cfgdir = cfgdir + self.config = config + self.lightning_config = lightning_config + + # Save a checkpoint if training is interrupted with keyboard interrupt + def on_keyboard_interrupt(self, trainer, pl_module): + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(self.ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + # Create necessary directories and save configuration files before training starts + # def on_pretrain_routine_start(self, trainer, pl_module): + def on_fit_start(self, trainer, pl_module): + if trainer.global_rank == 0: + # Create logdirs and save configs + os.makedirs(self.logdir, exist_ok=True) + os.makedirs(self.ckptdir, exist_ok=True) + os.makedirs(self.cfgdir, exist_ok=True) + + #Create trainstep checkpoint directory if necessary + if "callbacks" in self.lightning_config: + if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: + os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) + print("Project config") + print(OmegaConf.to_yaml(self.config)) + OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) + + # Save project config and lightning config as YAML files + print("Lightning config") + print(OmegaConf.to_yaml(self.lightning_config)) + OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), + os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) + + # Remove log directory if resuming training and directory already exists + else: + # ModelCheckpoint callback created log directory --- remove it + if not self.resume and os.path.exists(self.logdir): + dst, name = os.path.split(self.logdir) + dst = os.path.join(dst, "child_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + try: + os.rename(self.logdir, dst) + except FileNotFoundError: + pass + + # def on_fit_end(self, trainer, pl_module): + # if trainer.global_rank == 0: + # ckpt_path = os.path.join(self.ckptdir, "last.ckpt") + # rank_zero_info(f"Saving final checkpoint in {ckpt_path}.") + # trainer.save_checkpoint(ckpt_path) + + +# PyTorch Lightning callback for logging images during training and validation of a deep learning model +class ImageLogger(Callback): + + def __init__(self, + batch_frequency, # Frequency of batches on which to log images + max_images, # Maximum number of images to log + clamp=True, # Whether to clamp pixel values to [-1,1] + increase_log_steps=True, # Whether to increase frequency of log steps exponentially + rescale=True, # Whether to rescale pixel values to [0,1] + disabled=False, # Whether to disable logging + log_on_batch_idx=False, # Whether to log on batch index instead of global step + log_first_step=False, # Whether to log on the first step + log_images_kwargs=None): # Additional keyword arguments to pass to log_images method + super().__init__() + self.rescale = rescale + self.batch_freq = batch_frequency + self.max_images = max_images + self.logger_log_images = { + # Dictionary of logger classes and their corresponding logging methods + pl.loggers.CSVLogger: self._testtube, + } + # Create a list of exponentially increasing log steps, starting from 1 and ending at batch_frequency + self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)] + if not increase_log_steps: + self.log_steps = [self.batch_freq] + self.clamp = clamp + self.disabled = disabled + self.log_on_batch_idx = log_on_batch_idx + self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.log_first_step = log_first_step + + @rank_zero_only # Ensure that only the first process in distributed training executes this method + def _testtube(self, # The PyTorch Lightning module + pl_module, # A dictionary of images to log. + images, # + batch_idx, # The batch index. + split # The split (train/val) on which to log the images + ): + # Method for logging images using test-tube logger + for k in images: + grid = torchvision.utils.make_grid(images[k]) + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + + tag = f"{split}/{k}" + # Add image grid to logger's experiment + pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step) + + @rank_zero_only + def log_local(self, + save_dir, + split, # The split (train/val) on which to log the images + images, # A dictionary of images to log + global_step, # The global step + current_epoch, # The current epoch. + batch_idx + ): + # Method for saving image grids to local file system + root = os.path.join(save_dir, "images", split) + for k in images: + grid = torchvision.utils.make_grid(images[k], nrow=4) + if self.rescale: + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) + grid = grid.numpy() + grid = (grid * 255).astype(np.uint8) + filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) + path = os.path.join(root, filename) + os.makedirs(os.path.split(path)[0], exist_ok=True) + # Save image grid as PNG file + Image.fromarray(grid).save(path) + + def log_img(self, pl_module, batch, batch_idx, split="train"): + #Function for logging images to both the logger and local file system. + check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step + # check if it's time to log an image batch + if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 + hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0): + # Get logger type and check if training mode is on + logger = type(pl_module.logger) + + is_train = pl_module.training + if is_train: + pl_module.eval() + + with torch.no_grad(): + # Get images from log_images method of the pl_module + images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) + + # Clip images if specified and convert to CPU tensor + for k in images: + N = min(images[k].shape[0], self.max_images) + images[k] = images[k][:N] + if isinstance(images[k], torch.Tensor): + images[k] = images[k].detach().cpu() + if self.clamp: + images[k] = torch.clamp(images[k], -1., 1.) + + # Log images locally to file system + self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, + batch_idx) + + # log the images using the logger + logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) + logger_log_images(pl_module, images, pl_module.global_step, split) + + # switch back to training mode if necessary + if is_train: + pl_module.train() + + # The function checks if it's time to log an image batch + def check_frequency(self, check_idx): + if ((check_idx % self.batch_freq) == 0 or + (check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step): + try: + self.log_steps.pop(0) + except IndexError as e: + print(e) + pass + return True + return False + + # Log images on train batch end if logging is not disabled + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + # if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): + # self.log_img(pl_module, batch, batch_idx, split="train") + pass + + # Log images on validation batch end if logging is not disabled and in validation mode + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.disabled and pl_module.global_step > 0: + self.log_img(pl_module, batch, batch_idx, split="val") + # log gradients during calibration if necessary + if hasattr(pl_module, 'calibrate_grad_norm'): + if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: + self.log_gradients(trainer, pl_module, batch_idx=batch_idx) + + +class CUDACallback(Callback): + # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py + + def on_train_start(self, trainer, pl_module): + rank_zero_info("Training is starting") + + #the method is called at the end of each training epoch + def on_train_end(self, trainer, pl_module): + rank_zero_info("Training is ending") + + def on_train_epoch_start(self, trainer, pl_module): + # Reset the memory use counter + torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index) + torch.cuda.synchronize(trainer.strategy.root_device.index) + self.start_time = time.time() + + def on_train_epoch_end(self, trainer, pl_module): + torch.cuda.synchronize(trainer.strategy.root_device.index) + max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2**20 + epoch_time = time.time() - self.start_time + + try: + max_memory = trainer.strategy.reduce(max_memory) + epoch_time = trainer.strategy.reduce(epoch_time) + + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") + rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") + except AttributeError: + pass + + +if __name__ == "__main__": + # custom parser to specify config files, train, test and debug mode, + # postfix, resume. + # `--key value` arguments are interpreted as arguments to the trainer. + # `nested.key=value` arguments are interpreted as config parameters. + # configs are merged from left-to-right followed by command line parameters. + + # model: + # base_learning_rate: float + # target: path to lightning module + # params: + # key: value + # data: + # target: main.DataModuleFromConfig + # params: + # batch_size: int + # wrap: bool + # train: + # target: path to train dataset + # params: + # key: value + # validation: + # target: path to validation dataset + # params: + # key: value + # test: + # target: path to test dataset + # params: + # key: value + # lightning: (optional, has sane defaults and can be specified on cmdline) + # trainer: + # additional arguments to trainer + # logger: + # logger to instantiate + # modelcheckpoint: + # modelcheckpoint to instantiate + # callbacks: + # callback1: + # target: importpath + # params: + # key: value + + # get the current time to create a new logging directory + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + + # add cwd for convenience and to make classes in this file available when + # running as `python main.py` + # (in particular `main.DataModuleFromConfig`) + sys.path.append(os.getcwd()) + + parser = get_parser() + parser = Trainer.add_argparse_args(parser) + + opt, unknown = parser.parse_known_args() + # Verify the arguments are both specified + if opt.name and opt.resume: + raise ValueError("-n/--name and -r/--resume cannot be specified both." + "If you want to resume training in a new log folder, " + "use -n/--name in combination with --resume_from_checkpoint") + + # Check if the "resume" option is specified, resume training from the checkpoint if it is true + ckpt = None + if opt.resume: + rank_zero_info("Resuming from {}".format(opt.resume)) + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + paths = opt.resume.split("/") + # idx = len(paths)-paths[::-1].index("logs")+1 + # logdir = "/".join(paths[:idx]) + logdir = "/".join(paths[:-2]) + rank_zero_info("logdir: {}".format(logdir)) + ckpt = opt.resume + else: + assert os.path.isdir(opt.resume), opt.resume + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") + + # Finds all ".yaml" configuration files in the log directory and adds them to the list of base configurations + base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) + opt.base = base_configs + opt.base + # Gets the name of the current log directory by splitting the path and taking the last element. + _tmp = logdir.split("/") + nowname = _tmp[-1] + else: + if opt.name: + name = "_" + opt.name + elif opt.base: + rank_zero_info("Using base config {}".format(opt.base)) + cfg_fname = os.path.split(opt.base[0])[-1] + cfg_name = os.path.splitext(cfg_fname)[0] + name = "_" + cfg_name + else: + name = "" + nowname = now + name + opt.postfix + logdir = os.path.join(opt.logdir, nowname) + + # Sets the checkpoint path of the 'ckpt' option is specified + if opt.ckpt: + ckpt = opt.ckpt + + # Create the checkpoint and configuration directories within the log directory. + ckptdir = os.path.join(logdir, "checkpoints") + cfgdir = os.path.join(logdir, "configs") + # Sets the seed for the random number generator to ensure reproducibility + seed_everything(opt.seed) + + # Initialize and save configuration using teh OmegaConf library. + try: + # init and save configs + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + lightning_config = config.pop("lightning", OmegaConf.create()) + # merge trainer cli with config + trainer_config = lightning_config.get("trainer", OmegaConf.create()) + + for k in nondefault_trainer_args(opt): + trainer_config[k] = getattr(opt, k) + + # Check whether the accelerator is gpu + if not trainer_config["accelerator"] == "gpu": + del trainer_config["accelerator"] + cpu = True + else: + cpu = False + trainer_opt = argparse.Namespace(**trainer_config) + lightning_config.trainer = trainer_config + + # model + use_fp16 = trainer_config.get("precision", 32) == 16 + if use_fp16: + config.model["params"].update({"use_fp16": True}) + else: + config.model["params"].update({"use_fp16": False}) + + if ckpt is not None: + #If a checkpoint path is specified in the ckpt variable, the code updates the "ckpt" key in the "params" dictionary of the config.model configuration with the value of ckpt + config.model["params"].update({"ckpt": ckpt}) + rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"])) + + model = LatentDiffusion(**config.model.get("params", dict())) + # trainer and callbacks + trainer_kwargs = dict() + + # config the logger + # Default logger configs to log training metrics during the training process. + default_logger_cfgs = { + "wandb": { + "name": nowname, + "save_dir": logdir, + "offline": opt.debug, + "id": nowname, + } + , + "tensorboard": { + "save_dir": logdir, + "name": "diff_tb", + "log_graph": True + } + } + + # Set up the logger for TensorBoard + default_logger_cfg = default_logger_cfgs["tensorboard"] + if "logger" in lightning_config: + logger_cfg = lightning_config.logger + trainer_kwargs["logger"] = WandbLogger(**logger_cfg) + else: + logger_cfg = default_logger_cfg + trainer_kwargs["logger"] = TensorBoardLogger(**logger_cfg) + + # config the strategy, defualt is ddp + if "strategy" in trainer_config: + strategy_cfg = trainer_config["strategy"] + trainer_kwargs["strategy"] = ColossalAIStrategy(**strategy_cfg) + else: + strategy_cfg = {"find_unused_parameters": False} + trainer_kwargs["strategy"] = DDPStrategy(**strategy_cfg) + + # Set up ModelCheckpoint callback to save best models + # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to + # specify which metric is used to determine best models + default_modelckpt_cfg = { + "dirpath": ckptdir, + "filename": "{epoch:06}", + "verbose": True, + "save_last": True, + } + if hasattr(model, "monitor"): + default_modelckpt_cfg["monitor"] = model.monitor + default_modelckpt_cfg["save_top_k"] = 3 + + if "modelcheckpoint" in lightning_config: + modelckpt_cfg = lightning_config.modelcheckpoint["params"] + else: + modelckpt_cfg = OmegaConf.create() + modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) + if version.parse(pl.__version__) < version.parse('1.4.0'): + trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg) + + #Create an empty OmegaConf configuration object + + callbacks_cfg = OmegaConf.create() + + #Instantiate items according to the configs + trainer_kwargs.setdefault("callbacks", []) + setup_callback_config = { + "resume": opt.resume, # resume training if applicable + "now": now, + "logdir": logdir, # directory to save the log file + "ckptdir": ckptdir, # directory to save the checkpoint file + "cfgdir": cfgdir, # directory to save the configuration file + "config": config, # configuration dictionary + "lightning_config": lightning_config, # LightningModule configuration + } + trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config)) + + image_logger_config = { + + "batch_frequency": 750, # how frequently to log images + "max_images": 4, # maximum number of images to log + "clamp": True # whether to clamp pixel values to [0,1] + } + trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config)) + + learning_rate_logger_config = { + "logging_interval": "step", # logging frequency (either 'step' or 'epoch') + # "log_momentum": True # whether to log momentum (currently commented out) + } + trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config)) + + metrics_over_trainsteps_checkpoint_config= { + "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), + "filename": "{epoch:06}-{step:09}", + "verbose": True, + 'save_top_k': -1, + 'every_n_train_steps': 10000, + 'save_weights_only': True + } + trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_trainsteps_checkpoint_config)) + trainer_kwargs["callbacks"].append(CUDACallback()) + + # Create a Trainer object with the specified command-line arguments and keyword arguments, and set the log directory + trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) + trainer.logdir = logdir + + # Create a data module based on the configuration file + data = DataModuleFromConfig(**config.data) + + # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html + # calling these ourselves should not be necessary but it is. + # lightning still takes care of proper multiprocessing though + data.prepare_data() + data.setup() + + # Print some information about the datasets in the data module + for k in data.datasets: + rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") + + # Configure learning rate based on the batch size, base learning rate and number of GPUs + # If scale_lr is true, calculate the learning rate based on additional factors + bs, base_lr = config.data.batch_size, config.model.base_learning_rate + if not cpu: + ngpu = trainer_config["devices"] + else: + ngpu = 1 + if 'accumulate_grad_batches' in lightning_config.trainer: + accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches + else: + accumulate_grad_batches = 1 + rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}") + lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches + if opt.scale_lr: + model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr + rank_zero_info( + "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)" + .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) + else: + model.learning_rate = base_lr + rank_zero_info("++++ NOT USING LR SCALING ++++") + rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}") + + # Allow checkpointing via USR1 + def melk(*args, **kwargs): + # run all checkpoint hooks + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + def divein(*args, **kwargs): + if trainer.global_rank == 0: + import pudb + pudb.set_trace() + + import signal + # Assign melk to SIGUSR1 signal and divein to SIGUSR2 signal + signal.signal(signal.SIGUSR1, melk) + signal.signal(signal.SIGUSR2, divein) + + # Run the training and validation + if opt.train: + try: + trainer.fit(model, data) + except Exception: + melk() + raise + # Print the maximum GPU memory allocated during training + print(f"GPU memory usage: {torch.cuda.max_memory_allocated() / 1024**2:.0f} MB") + # if not opt.no_test and not trainer.interrupted: + # trainer.test(model, data) + except Exception: + # If there's an exception, debug it if opt.debug is true and the trainer's global rank is 0 + if opt.debug and trainer.global_rank == 0: + try: + import pudb as debugger + except ImportError: + import pdb as debugger + debugger.post_mortem() + raise + finally: + # Move the log directory to debug_runs if opt.debug is true and the trainer's global + if opt.debug and not opt.resume and trainer.global_rank == 0: + dst, name = os.path.split(logdir) + dst = os.path.join(dst, "debug_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + os.rename(logdir, dst) + if trainer.global_rank == 0: + print(trainer.profiler.summary()) diff --git a/examples/images/diffusion/requirements.txt b/examples/images/diffusion/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..59d027fcf60f74e6023531ec4a68d72602c91681 --- /dev/null +++ b/examples/images/diffusion/requirements.txt @@ -0,0 +1,19 @@ +albumentations==1.3.0 +opencv-python==4.6.0.66 +pudb==2019.2 +prefetch_generator +imageio==2.9.0 +imageio-ffmpeg==0.4.2 +torchmetrics==0.7 +omegaconf==2.1.1 +test-tube>=0.7.5 +streamlit>=0.73.1 +einops==0.3.0 +transformers +webdataset==0.2.5 +open-clip-torch==2.7.0 +gradio==3.11 +lightning==1.9.0 +datasets +colossalai +-e . diff --git a/examples/images/diffusion/scripts/download_first_stages.sh b/examples/images/diffusion/scripts/download_first_stages.sh new file mode 100755 index 0000000000000000000000000000000000000000..a8d79e99ccdff0a8d8762f23f3c0642401f32f6c --- /dev/null +++ b/examples/images/diffusion/scripts/download_first_stages.sh @@ -0,0 +1,41 @@ +#!/bin/bash +wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip +wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip +wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip +wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip +wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip +wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip +wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip +wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip +wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip + + + +cd models/first_stage_models/kl-f4 +unzip -o model.zip + +cd ../kl-f8 +unzip -o model.zip + +cd ../kl-f16 +unzip -o model.zip + +cd ../kl-f32 +unzip -o model.zip + +cd ../vq-f4 +unzip -o model.zip + +cd ../vq-f4-noattn +unzip -o model.zip + +cd ../vq-f8 +unzip -o model.zip + +cd ../vq-f8-n256 +unzip -o model.zip + +cd ../vq-f16 +unzip -o model.zip + +cd ../.. \ No newline at end of file diff --git a/examples/images/diffusion/scripts/download_models.sh b/examples/images/diffusion/scripts/download_models.sh new file mode 100755 index 0000000000000000000000000000000000000000..84297d7b8b9a78d241edcd5adaf7d9aa273790de --- /dev/null +++ b/examples/images/diffusion/scripts/download_models.sh @@ -0,0 +1,49 @@ +#!/bin/bash +wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip +wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip +wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip +wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip +wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip +wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip +wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip +wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip +wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip +wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip +wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip + + + +cd models/ldm/celeba256 +unzip -o celeba-256.zip + +cd ../ffhq256 +unzip -o ffhq-256.zip + +cd ../lsun_churches256 +unzip -o lsun_churches-256.zip + +cd ../lsun_beds256 +unzip -o lsun_beds-256.zip + +cd ../text2img256 +unzip -o model.zip + +cd ../cin256 +unzip -o model.zip + +cd ../semantic_synthesis512 +unzip -o model.zip + +cd ../semantic_synthesis256 +unzip -o model.zip + +cd ../bsr_sr +unzip -o model.zip + +cd ../layout2img-openimages256 +unzip -o model.zip + +cd ../inpainting_big +unzip -o model.zip + +cd ../.. diff --git a/examples/images/diffusion/scripts/img2img.py b/examples/images/diffusion/scripts/img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..877538d4733dd06ab68d14b5205491fd516cfae2 --- /dev/null +++ b/examples/images/diffusion/scripts/img2img.py @@ -0,0 +1,296 @@ +"""make variations of input image""" + +import argparse, os +import PIL +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange, repeat +from torchvision.utils import make_grid +from torch import autocast +from contextlib import nullcontext +try: + from lightning.pytorch import seed_everything +except: + from pytorch_lightning import seed_everything +from imwatermark import WatermarkEncoder + + +from scripts.txt2img import put_watermark +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from utils import replace_module, getModelSize + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.eval() + return model + + +def load_img(path): + image = Image.open(path).convert("RGB") + w, h = image.size + print(f"loaded input image of size ({w}, {h}) from {path}") + w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2. * image - 1. + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + + parser.add_argument( + "--init-img", + type=str, + nargs="?", + help="path to the input image" + ) + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/img2img-samples" + ) + + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across all samples ", + ) + + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor, most often 8 or 16", + ) + + parser.add_argument( + "--n_samples", + type=int, + default=2, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + + parser.add_argument( + "--scale", + type=float, + default=9.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + + parser.add_argument( + "--strength", + type=float, + default=0.8, + help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image", + ) + + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v2-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + parser.add_argument( + "--use_int8", + type=bool, + default=False, + help="use int8 for inference", + ) + + opt = parser.parse_args() + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + # quantize model + if opt.use_int8: + model = replace_module(model) + # # to compute the model size + # getModelSize(model) + + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "SDV2" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + assert os.path.isfile(opt.init_img) + init_image = load_img(opt.init_img).to(device) + init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) + init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space + + sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False) + + assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(opt.strength * opt.ddim_steps) + print(f"target t_enc is {t_enc} steps") + + precision_scope = autocast if opt.precision == "autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + + # encode (scaled latent) + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device)) + # decode it + samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, ) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples) + + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + grid = Image.fromarray(grid.astype(np.uint8)) + grid = put_watermark(grid, wm_encoder) + grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") + + +if __name__ == "__main__": + main() + # # to compute the mem allocated + # print(torch.cuda.max_memory_allocated() / 1024 / 1024) diff --git a/examples/images/diffusion/scripts/inpaint.py b/examples/images/diffusion/scripts/inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e6387a9a3b0afa73fae8af25f43a8ba856240e --- /dev/null +++ b/examples/images/diffusion/scripts/inpaint.py @@ -0,0 +1,98 @@ +import argparse, os, sys, glob +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm +import numpy as np +import torch +from main import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler + + +def make_batch(image, mask, device): + image = np.array(Image.open(image).convert("RGB")) + image = image.astype(np.float32)/255.0 + image = image[None].transpose(0,3,1,2) + image = torch.from_numpy(image) + + mask = np.array(Image.open(mask).convert("L")) + mask = mask.astype(np.float32)/255.0 + mask = mask[None,None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = (1-mask)*image + + batch = {"image": image, "mask": mask, "masked_image": masked_image} + for k in batch: + batch[k] = batch[k].to(device=device) + batch[k] = batch[k]*2.0-1.0 + return batch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--indir", + type=str, + nargs="?", + help="dir containing image-mask pairs (`example.png` and `example_mask.png`)", + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + opt = parser.parse_args() + + masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png"))) + images = [x.replace("_mask.png", ".png") for x in masks] + print(f"Found {len(masks)} inputs.") + + config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") + model = instantiate_from_config(config.model) + model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], + strict=False) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + with torch.no_grad(): + with model.ema_scope(): + for image, mask in tqdm(zip(images, masks)): + outpath = os.path.join(opt.outdir, os.path.split(image)[1]) + batch = make_batch(image, mask, device=device) + + # encode masked image and concat downsampled mask + c = model.cond_stage_model.encode(batch["masked_image"]) + cc = torch.nn.functional.interpolate(batch["mask"], + size=c.shape[-2:]) + c = torch.cat((c, cc), dim=1) + + shape = (c.shape[1]-1,)+c.shape[2:] + samples_ddim, _ = sampler.sample(S=opt.steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False) + x_samples_ddim = model.decode_first_stage(samples_ddim) + + image = torch.clamp((batch["image"]+1.0)/2.0, + min=0.0, max=1.0) + mask = torch.clamp((batch["mask"]+1.0)/2.0, + min=0.0, max=1.0) + predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, + min=0.0, max=1.0) + + inpainted = (1-mask)*image+mask*predicted_image + inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 + Image.fromarray(inpainted.astype(np.uint8)).save(outpath) diff --git a/examples/images/diffusion/scripts/knn2img.py b/examples/images/diffusion/scripts/knn2img.py new file mode 100644 index 0000000000000000000000000000000000000000..e6eaaecab53eac9c97051c9a5cb457a240679725 --- /dev/null +++ b/examples/images/diffusion/scripts/knn2img.py @@ -0,0 +1,398 @@ +import argparse, os, sys, glob +import clip +import torch +import torch.nn as nn +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange, repeat +from torchvision.utils import make_grid +import scann +import time +from multiprocessing import cpu_count + +from ldm.util import instantiate_from_config, parallel_data_prefetch +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder + +DATABASES = [ + "openimages", + "artbench-art_nouveau", + "artbench-baroque", + "artbench-expressionism", + "artbench-impressionism", + "artbench-post_impressionism", + "artbench-realism", + "artbench-romanticism", + "artbench-renaissance", + "artbench-surrealism", + "artbench-ukiyo_e", +] + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +class Searcher(object): + def __init__(self, database, retriever_version='ViT-L/14'): + assert database in DATABASES + # self.database = self.load_database(database) + self.database_name = database + self.searcher_savedir = f'data/rdm/searchers/{self.database_name}' + self.database_path = f'data/rdm/retrieval_databases/{self.database_name}' + self.retriever = self.load_retriever(version=retriever_version) + self.database = {'embedding': [], + 'img_id': [], + 'patch_coords': []} + self.load_database() + self.load_searcher() + + def train_searcher(self, k, + metric='dot_product', + searcher_savedir=None): + + print('Start training searcher') + searcher = scann.scann_ops_pybind.builder(self.database['embedding'] / + np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis], + k, metric) + self.searcher = searcher.score_brute_force().build() + print('Finish training searcher') + + if searcher_savedir is not None: + print(f'Save trained searcher under "{searcher_savedir}"') + os.makedirs(searcher_savedir, exist_ok=True) + self.searcher.serialize(searcher_savedir) + + def load_single_file(self, saved_embeddings): + compressed = np.load(saved_embeddings) + self.database = {key: compressed[key] for key in compressed.files} + print('Finished loading of clip embeddings.') + + def load_multi_files(self, data_archive): + out_data = {key: [] for key in self.database} + for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for key in d.files: + out_data[key].append(d[key]) + + return out_data + + def load_database(self): + + print(f'Load saved patch embedding from "{self.database_path}"') + file_content = glob.glob(os.path.join(self.database_path, '*.npz')) + + if len(file_content) == 1: + self.load_single_file(file_content[0]) + elif len(file_content) > 1: + data = [np.load(f) for f in file_content] + prefetched_data = parallel_data_prefetch(self.load_multi_files, data, + n_proc=min(len(data), cpu_count()), target_data_type='dict') + + self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in + self.database} + else: + raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?') + + print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.') + + def load_retriever(self, version='ViT-L/14', ): + model = FrozenClipImageEmbedder(model=version) + if torch.cuda.is_available(): + model.cuda() + model.eval() + return model + + def load_searcher(self): + print(f'load searcher for database {self.database_name} from {self.searcher_savedir}') + self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir) + print('Finished loading searcher.') + + def search(self, x, k): + if self.searcher is None and self.database['embedding'].shape[0] < 2e4: + self.train_searcher(k) # quickly fit searcher on the fly for small databases + assert self.searcher is not None, 'Cannot search with uninitialized searcher' + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + if len(x.shape) == 3: + x = x[:, 0] + query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis] + + start = time.time() + nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k) + end = time.time() + + out_embeddings = self.database['embedding'][nns] + out_img_ids = self.database['img_id'][nns] + out_pc = self.database['patch_coords'][nns] + + out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], + 'img_ids': out_img_ids, + 'patch_coords': out_pc, + 'queries': x, + 'exec_time': end - start, + 'nns': nns, + 'q_embeddings': query_embeddings} + + return out + + def __call__(self, x, n): + return self.search(x, n) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc) + # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt? + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + + parser.add_argument( + "--n_repeat", + type=int, + default=1, + help="number of repeats in CLIP latent space", + ) + + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + + parser.add_argument( + "--H", + type=int, + default=768, + help="image height, in pixel space", + ) + + parser.add_argument( + "--W", + type=int, + default=768, + help="image width, in pixel space", + ) + + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + + parser.add_argument( + "--scale", + type=float, + default=5.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + + parser.add_argument( + "--config", + type=str, + default="configs/retrieval-augmented-diffusion/768x768.yaml", + help="path to config which constructs model", + ) + + parser.add_argument( + "--ckpt", + type=str, + default="models/rdm/rdm768x768/model.ckpt", + help="path to checkpoint of model", + ) + + parser.add_argument( + "--clip_type", + type=str, + default="ViT-L/14", + help="which CLIP model to use for retrieval and NN encoding", + ) + parser.add_argument( + "--database", + type=str, + default='artbench-surrealism', + choices=DATABASES, + help="The database used for the search, only applied when --use_neighbors=True", + ) + parser.add_argument( + "--use_neighbors", + default=False, + action='store_true', + help="Include neighbors in addition to text prompt for conditioning", + ) + parser.add_argument( + "--knn", + default=10, + type=int, + help="The number of included neighbors, only applied when --use_neighbors=True", + ) + + opt = parser.parse_args() + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device) + + if opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + print(f"sampling scale for cfg is {opt.scale:.2f}") + + searcher = None + if opt.use_neighbors: + searcher = Searcher(opt.database) + + with torch.no_grad(): + with model.ema_scope(): + for n in trange(opt.n_iter, desc="Sampling"): + all_samples = list() + for prompts in tqdm(data, desc="data"): + print("sampling prompts:", prompts) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = clip_text_encoder.encode(prompts) + uc = None + if searcher is not None: + nn_dict = searcher(c, opt.knn) + c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1) + if opt.scale != 1.0: + uc = torch.zeros_like(c) + if isinstance(prompts, tuple): + prompts = list(prompts) + shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + ) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples_ddim: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + Image.fromarray(x_sample.astype(np.uint8)).save( + os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples_ddim) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/examples/images/diffusion/scripts/sample_diffusion.py b/examples/images/diffusion/scripts/sample_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..876fe3c3642fcc8c7209e4f763c0134166615f78 --- /dev/null +++ b/examples/images/diffusion/scripts/sample_diffusion.py @@ -0,0 +1,313 @@ +import argparse, os, sys, glob, datetime, yaml +import torch +import time +import numpy as np +from tqdm import trange + +from omegaconf import OmegaConf +from PIL import Image + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config + +rescale = lambda x: (x + 1.) / 2. + +def custom_to_pil(x): + x = x.detach().cpu() + x = torch.clamp(x, -1., 1.) + x = (x + 1.) / 2. + x = x.permute(1, 2, 0).numpy() + x = (255 * x).astype(np.uint8) + x = Image.fromarray(x) + if not x.mode == "RGB": + x = x.convert("RGB") + return x + + +def custom_to_np(x): + # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py + sample = x.detach().cpu() + sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8) + sample = sample.permute(0, 2, 3, 1) + sample = sample.contiguous() + return sample + + +def logs2pil(logs, keys=["sample"]): + imgs = dict() + for k in logs: + try: + if len(logs[k].shape) == 4: + img = custom_to_pil(logs[k][0, ...]) + elif len(logs[k].shape) == 3: + img = custom_to_pil(logs[k]) + else: + print(f"Unknown format for key {k}. ") + img = None + except: + img = None + imgs[k] = img + return imgs + + +@torch.no_grad() +def convsample(model, shape, return_intermediates=True, + verbose=True, + make_prog_row=False): + + + if not make_prog_row: + return model.p_sample_loop(None, shape, + return_intermediates=return_intermediates, verbose=verbose) + else: + return model.progressive_denoising( + None, shape, verbose=True + ) + + +@torch.no_grad() +def convsample_ddim(model, steps, shape, eta=1.0 + ): + ddim = DDIMSampler(model) + bs = shape[0] + shape = shape[1:] + samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,) + return samples, intermediates + + +@torch.no_grad() +def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,): + + + log = dict() + + shape = [batch_size, + model.model.diffusion_model.in_channels, + model.model.diffusion_model.image_size, + model.model.diffusion_model.image_size] + + with model.ema_scope("Plotting"): + t0 = time.time() + if vanilla: + sample, progrow = convsample(model, shape, + make_prog_row=True) + else: + sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, + eta=eta) + + t1 = time.time() + + x_sample = model.decode_first_stage(sample) + + log["sample"] = x_sample + log["time"] = t1 - t0 + log['throughput'] = sample.shape[0] / (t1 - t0) + print(f'Throughput for this batch: {log["throughput"]}') + return log + +def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None): + if vanilla: + print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.') + else: + print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}') + + + tstart = time.time() + n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1 + # path = logdir + if model.cond_stage_model is None: + all_images = [] + + print(f"Running unconditional sampling for {n_samples} samples") + for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"): + logs = make_convolutional_sample(model, batch_size=batch_size, + vanilla=vanilla, custom_steps=custom_steps, + eta=eta) + n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample") + all_images.extend([custom_to_np(logs["sample"])]) + if n_saved >= n_samples: + print(f'Finish after generating {n_saved} samples') + break + all_img = np.concatenate(all_images, axis=0) + all_img = all_img[:n_samples] + shape_str = "x".join([str(x) for x in all_img.shape]) + nppath = os.path.join(nplog, f"{shape_str}-samples.npz") + np.savez(nppath, all_img) + + else: + raise NotImplementedError('Currently only sampling for unconditional models supported.') + + print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.") + + +def save_logs(logs, path, n_saved=0, key="sample", np_path=None): + for k in logs: + if k == key: + batch = logs[key] + if np_path is None: + for x in batch: + img = custom_to_pil(x) + imgpath = os.path.join(path, f"{key}_{n_saved:06}.png") + img.save(imgpath) + n_saved += 1 + else: + npbatch = custom_to_np(batch) + shape_str = "x".join([str(x) for x in npbatch.shape]) + nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz") + np.savez(nppath, npbatch) + n_saved += npbatch.shape[0] + return n_saved + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-r", + "--resume", + type=str, + nargs="?", + help="load from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-n", + "--n_samples", + type=int, + nargs="?", + help="number of samples to draw", + default=50000 + ) + parser.add_argument( + "-e", + "--eta", + type=float, + nargs="?", + help="eta for ddim sampling (0.0 yields deterministic sampling)", + default=1.0 + ) + parser.add_argument( + "-v", + "--vanilla_sample", + default=False, + action='store_true', + help="vanilla sampling (default option is DDIM sampling)?", + ) + parser.add_argument( + "-l", + "--logdir", + type=str, + nargs="?", + help="extra logdir", + default="none" + ) + parser.add_argument( + "-c", + "--custom_steps", + type=int, + nargs="?", + help="number of steps for ddim and fastdpm sampling", + default=50 + ) + parser.add_argument( + "--batch_size", + type=int, + nargs="?", + help="the bs", + default=10 + ) + return parser + + +def load_model_from_config(config, sd): + model = instantiate_from_config(config) + model.load_state_dict(sd,strict=False) + model.cuda() + model.eval() + return model + + +def load_model(config, ckpt, gpu, eval_mode): + if ckpt: + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + global_step = pl_sd["global_step"] + else: + pl_sd = {"state_dict": None} + global_step = None + model = load_model_from_config(config.model, + pl_sd["state_dict"]) + + return model, global_step + + +if __name__ == "__main__": + now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + sys.path.append(os.getcwd()) + command = " ".join(sys.argv) + + parser = get_parser() + opt, unknown = parser.parse_known_args() + ckpt = None + + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + # paths = opt.resume.split("/") + try: + logdir = '/'.join(opt.resume.split('/')[:-1]) + # idx = len(paths)-paths[::-1].index("logs")+1 + print(f'Logdir is {logdir}') + except ValueError: + paths = opt.resume.split("/") + idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt + logdir = "/".join(paths[:idx]) + ckpt = opt.resume + else: + assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory" + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "model.ckpt") + + base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml"))) + opt.base = base_configs + + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + + gpu = True + eval_mode = True + + if opt.logdir != "none": + locallog = logdir.split(os.sep)[-1] + if locallog == "": locallog = logdir.split(os.sep)[-2] + print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'") + logdir = os.path.join(opt.logdir, locallog) + + print(config) + + model, global_step = load_model(config, ckpt, gpu, eval_mode) + print(f"global step: {global_step}") + print(75 * "=") + print("logging to:") + logdir = os.path.join(logdir, "samples", f"{global_step:08}", now) + imglogdir = os.path.join(logdir, "img") + numpylogdir = os.path.join(logdir, "numpy") + + os.makedirs(imglogdir) + os.makedirs(numpylogdir) + print(logdir) + print(75 * "=") + + # write config out + sampling_file = os.path.join(logdir, "sampling_config.yaml") + sampling_conf = vars(opt) + + with open(sampling_file, 'w') as f: + yaml.dump(sampling_conf, f, default_flow_style=False) + print(sampling_conf) + + + run(model, imglogdir, eta=opt.eta, + vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps, + batch_size=opt.batch_size, nplog=numpylogdir) + + print("done.") diff --git a/examples/images/diffusion/scripts/tests/test_checkpoint.py b/examples/images/diffusion/scripts/tests/test_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..13622c4989fd48d44c49c875de8417b0ee7710cf --- /dev/null +++ b/examples/images/diffusion/scripts/tests/test_checkpoint.py @@ -0,0 +1,38 @@ +import os +import sys +from copy import deepcopy + +import yaml +from datetime import datetime + +from diffusers import StableDiffusionPipeline +import torch + +from main import get_parser +from ldm.modules.diffusionmodules.openaimodel import UNetModel + +if __name__ == "__main__": + with torch.no_grad(): + yaml_path = "../../train_colossalai.yaml" + with open(yaml_path, 'r', encoding='utf-8') as f: + config = f.read() + base_config = yaml.load(config, Loader=yaml.FullLoader) + unet_config = base_config['model']['params']['unet_config'] + diffusion_model = UNetModel(**unet_config).to("cuda:0") + + pipe = StableDiffusionPipeline.from_pretrained( + "/data/scratch/diffuser/stable-diffusion-v1-4" + ).to("cuda:0") + dif_model_2 = pipe.unet + + random_input_ = torch.rand((4, 4, 32, 32)).to("cuda:0") + random_input_2 = torch.clone(random_input_).to("cuda:0") + time_stamp = torch.randint(20, (4,)).to("cuda:0") + time_stamp2 = torch.clone(time_stamp).to("cuda:0") + context_ = torch.rand((4, 77, 768)).to("cuda:0") + context_2 = torch.clone(context_).to("cuda:0") + + out_1 = diffusion_model(random_input_, time_stamp, context_) + out_2 = dif_model_2(random_input_2, time_stamp2, context_2) + print(out_1.shape) + print(out_2['sample'].shape) \ No newline at end of file diff --git a/examples/images/diffusion/scripts/tests/test_watermark.py b/examples/images/diffusion/scripts/tests/test_watermark.py new file mode 100644 index 0000000000000000000000000000000000000000..f93f8a6e70763c0e284157bc8225827520b2f5ef --- /dev/null +++ b/examples/images/diffusion/scripts/tests/test_watermark.py @@ -0,0 +1,18 @@ +import cv2 +import fire +from imwatermark import WatermarkDecoder + + +def testit(img_path): + bgr = cv2.imread(img_path) + decoder = WatermarkDecoder('bytes', 136) + watermark = decoder.decode(bgr, 'dwtDct') + try: + dec = watermark.decode('utf-8') + except: + dec = "null" + print(dec) + + +if __name__ == "__main__": + fire.Fire(testit) \ No newline at end of file diff --git a/examples/images/diffusion/scripts/train_searcher.py b/examples/images/diffusion/scripts/train_searcher.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7904889c0145f9fb740fd4ae8e45c08728b255 --- /dev/null +++ b/examples/images/diffusion/scripts/train_searcher.py @@ -0,0 +1,147 @@ +import os, sys +import numpy as np +import scann +import argparse +import glob +from multiprocessing import cpu_count +from tqdm import tqdm + +from ldm.util import parallel_data_prefetch + + +def search_bruteforce(searcher): + return searcher.score_brute_force().build() + + +def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, + partioning_trainsize, num_leaves, num_leaves_to_search): + return searcher.tree(num_leaves=num_leaves, + num_leaves_to_search=num_leaves_to_search, + training_sample_size=partioning_trainsize). \ + score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() + + +def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): + return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( + reorder_k).build() + +def load_datapool(dpath): + + + def load_single_file(saved_embeddings): + compressed = np.load(saved_embeddings) + database = {key: compressed[key] for key in compressed.files} + return database + + def load_multi_files(data_archive): + database = {key: [] for key in data_archive[0].files} + for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for key in d.files: + database[key].append(d[key]) + + return database + + print(f'Load saved patch embedding from "{dpath}"') + file_content = glob.glob(os.path.join(dpath, '*.npz')) + + if len(file_content) == 1: + data_pool = load_single_file(file_content[0]) + elif len(file_content) > 1: + data = [np.load(f) for f in file_content] + prefetched_data = parallel_data_prefetch(load_multi_files, data, + n_proc=min(len(data), cpu_count()), target_data_type='dict') + + data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} + else: + raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') + + print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.') + return data_pool + + +def train_searcher(opt, + metric='dot_product', + partioning_trainsize=None, + reorder_k=None, + # todo tune + aiq_thld=0.2, + dims_per_block=2, + num_leaves=None, + num_leaves_to_search=None,): + + data_pool = load_datapool(opt.database) + k = opt.knn + + if not reorder_k: + reorder_k = 2 * k + + # normalize + # embeddings = + searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) + pool_size = data_pool['embedding'].shape[0] + + print(*(['#'] * 100)) + print('Initializing scaNN searcher with the following values:') + print(f'k: {k}') + print(f'metric: {metric}') + print(f'reorder_k: {reorder_k}') + print(f'anisotropic_quantization_threshold: {aiq_thld}') + print(f'dims_per_block: {dims_per_block}') + print(*(['#'] * 100)) + print('Start training searcher....') + print(f'N samples in pool is {pool_size}') + + # this reflects the recommended design choices proposed at + # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md + if pool_size < 2e4: + print('Using brute force search.') + searcher = search_bruteforce(searcher) + elif 2e4 <= pool_size and pool_size < 1e5: + print('Using asymmetric hashing search and reordering.') + searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) + else: + print('Using using partioning, asymmetric hashing search and reordering.') + + if not partioning_trainsize: + partioning_trainsize = data_pool['embedding'].shape[0] // 10 + if not num_leaves: + num_leaves = int(np.sqrt(pool_size)) + + if not num_leaves_to_search: + num_leaves_to_search = max(num_leaves // 20, 1) + + print('Partitioning params:') + print(f'num_leaves: {num_leaves}') + print(f'num_leaves_to_search: {num_leaves_to_search}') + # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) + searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, + partioning_trainsize, num_leaves, num_leaves_to_search) + + print('Finish training searcher') + searcher_savedir = opt.target_path + os.makedirs(searcher_savedir, exist_ok=True) + searcher.serialize(searcher_savedir) + print(f'Saved trained searcher under "{searcher_savedir}"') + +if __name__ == '__main__': + sys.path.append(os.getcwd()) + parser = argparse.ArgumentParser() + parser.add_argument('--database', + '-d', + default='data/rdm/retrieval_databases/openimages', + type=str, + help='path to folder containing the clip feature of the database') + parser.add_argument('--target_path', + '-t', + default='data/rdm/searchers/openimages', + type=str, + help='path to the target folder where the searcher shall be stored.') + parser.add_argument('--knn', + '-k', + default=20, + type=int, + help='number of nearest neighbors, for which the searcher shall be optimized') + + opt, _ = parser.parse_known_args() + + train_searcher(opt,) \ No newline at end of file diff --git a/examples/images/diffusion/scripts/txt2img.py b/examples/images/diffusion/scripts/txt2img.py new file mode 100644 index 0000000000000000000000000000000000000000..364ebac6c67b62532b5bf0c26187dc366bec81af --- /dev/null +++ b/examples/images/diffusion/scripts/txt2img.py @@ -0,0 +1,307 @@ +import argparse, os +import cv2 +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +try: + from lightning.pytorch import seed_everything +except: + from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import nullcontext +from imwatermark import WatermarkEncoder + +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.dpm_solver import DPMSolverSampler +from utils import replace_module, getModelSize + +torch.set_grad_enabled(False) + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.eval() + return model + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a professional photograph of an astronaut riding a triceratops", + help="the prompt to render" + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--dpm", + action='store_true', + help="use DPM (2) sampler", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across all samples ", + ) + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=3, + help="sample this often", + ) + parser.add_argument( + "--H", + type=int, + default=512, + help="image height, in pixel space", + ) + parser.add_argument( + "--W", + type=int, + default=512, + help="image width, in pixel space", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor, most often 8 or 16", + ) + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=9.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file, separated by newlines", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v2-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + parser.add_argument( + "--repeat", + type=int, + default=1, + help="repeat each prompt in file this often", + ) + parser.add_argument( + "--use_int8", + type=bool, + default=False, + help="use int8 for inference", + ) + opt = parser.parse_args() + return opt + + +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + + +def main(opt): + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + model = model.to(device) + + # quantize model + if opt.use_int8: + model = replace_module(model) + # # to compute the model size + # getModelSize(model) + + if opt.plms: + sampler = PLMSSampler(model) + elif opt.dpm: + sampler = DPMSolverSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "SDV2" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = [p for p in data for i in range(opt.repeat)] + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + sample_count = 0 + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + start_code = None + if opt.fixed_code: + start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + + precision_scope = autocast if opt.precision == "autocast" else nullcontext + with torch.no_grad(), \ + precision_scope("cuda"), \ + model.ema_scope(): + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples, _ = sampler.sample(S=opt.steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + sample_count += 1 + + all_samples.append(x_samples) + + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + grid = Image.fromarray(grid.astype(np.uint8)) + grid = put_watermark(grid, wm_encoder) + grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + + +if __name__ == "__main__": + opt = parse_args() + main(opt) + # # to compute the mem allocated + # print(torch.cuda.max_memory_allocated() / 1024 / 1024) diff --git a/examples/images/diffusion/scripts/txt2img.sh b/examples/images/diffusion/scripts/txt2img.sh new file mode 100755 index 0000000000000000000000000000000000000000..bc6480b6bdaaf430f56706867e87b9830d597148 --- /dev/null +++ b/examples/images/diffusion/scripts/txt2img.sh @@ -0,0 +1,5 @@ +python scripts/txt2img.py --prompt "Teyvat, Medium Female, a woman in a blue outfit holding a sword" --plms \ + --outdir ./output \ + --ckpt checkpoints/last.ckpt \ + --config configs/2023-02-02T18-06-14-project.yaml \ + --n_samples 4 diff --git a/examples/images/diffusion/scripts/utils.py b/examples/images/diffusion/scripts/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c954b22ca19045c985e2f0bc8da32fb79474996d --- /dev/null +++ b/examples/images/diffusion/scripts/utils.py @@ -0,0 +1,83 @@ +import bitsandbytes as bnb +import torch.nn as nn +import torch + +class Linear8bit(nn.Linear): + def __init__( + self, + input_features, + output_features, + bias=True, + has_fp16_weights=False, + memory_efficient_backward=False, + threshold=6.0, + weight_data=None, + bias_data=None + ): + super(Linear8bit, self).__init__( + input_features, output_features, bias + ) + self.state = bnb.MatmulLtState() + self.bias = bias_data + self.state.threshold = threshold + self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward + if threshold > 0.0 and not has_fp16_weights: + self.state.use_pool = True + + self.register_parameter("SCB", nn.Parameter(torch.empty(0), requires_grad=False)) + self.weight = weight_data + self.quant() + + + def quant(self): + weight = self.weight.data.contiguous().half().cuda() + CB, _, SCB, _, _ = bnb.functional.double_quant(weight) + delattr(self, "weight") + setattr(self, "weight", nn.Parameter(CB, requires_grad=False)) + delattr(self, "SCB") + setattr(self, "SCB", nn.Parameter(SCB, requires_grad=False)) + del weight + + def forward(self, x): + self.state.is_training = self.training + + if self.bias is not None and self.bias.dtype != torch.float16: + self.bias.data = self.bias.data.half() + + self.state.CB = self.weight.data + self.state.SCB = self.SCB.data + + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) + del self.state.CxB + return out + +def replace_module(model): + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_module(module) + + if isinstance(module, nn.Linear) and "out_proj" not in name: + model._modules[name] = Linear8bit( + input_features=module.in_features, + output_features=module.out_features, + threshold=6.0, + weight_data=module.weight, + bias_data=module.bias, + ) + return model + +def getModelSize(model): + param_size = 0 + param_sum = 0 + for param in model.parameters(): + param_size += param.nelement() * param.element_size() + param_sum += param.nelement() + buffer_size = 0 + buffer_sum = 0 + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + buffer_sum += buffer.nelement() + all_size = (param_size + buffer_size) / 1024 / 1024 + print('Model Size: {:.3f}MB'.format(all_size)) + return (param_size, param_sum, buffer_size, buffer_sum, all_size) diff --git a/examples/images/diffusion/setup.py b/examples/images/diffusion/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..a24d541676407eee1bea271179ffd1d80c6a8e79 --- /dev/null +++ b/examples/images/diffusion/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages + +setup( + name='latent-diffusion', + version='0.0.1', + description='', + packages=find_packages(), + install_requires=[ + 'torch', + 'numpy', + 'tqdm', + ], +) \ No newline at end of file diff --git a/examples/images/diffusion/test_ci.sh b/examples/images/diffusion/test_ci.sh new file mode 100755 index 0000000000000000000000000000000000000000..44cf470466846d1fecc90baa7ad84c3c68102df8 --- /dev/null +++ b/examples/images/diffusion/test_ci.sh @@ -0,0 +1,15 @@ +#!/bin/bash +set -euxo pipefail + +conda env create -f environment.yaml + +conda activate ldm + +conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch +pip install transformers diffusers invisible-watermark + +CUDA_EXT=1 pip install colossalai + +wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt + +python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt 512-base-ema.ckpt diff --git a/examples/images/diffusion/train_colossalai.sh b/examples/images/diffusion/train_colossalai.sh new file mode 100755 index 0000000000000000000000000000000000000000..7f1a1bd14615a66e948879ffa49b4aae8a542d0a --- /dev/null +++ b/examples/images/diffusion/train_colossalai.sh @@ -0,0 +1,6 @@ +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 +DIFFUSERS_OFFLINE=1 + +python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt diffuser_root_dir/512-base-ema.ckpt + diff --git a/examples/images/diffusion/train_ddp.sh b/examples/images/diffusion/train_ddp.sh new file mode 100644 index 0000000000000000000000000000000000000000..78fe765488c6451e8eeb23f217d49f750aabfd03 --- /dev/null +++ b/examples/images/diffusion/train_ddp.sh @@ -0,0 +1,5 @@ +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 +DIFFUSERS_OFFLINE=1 + +python main.py --logdir /tmp -t -b /configs/train_ddp.yaml diff --git a/examples/images/dreambooth/README.md b/examples/images/dreambooth/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7c117d841e24ccbb1e5d8cc7856df044986cc03e --- /dev/null +++ b/examples/images/dreambooth/README.md @@ -0,0 +1,123 @@ +# [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) by [colossalai](https://github.com/hpcaitech/ColossalAI.git) + +[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject. +The `train_dreambooth_colossalai.py` script shows how to implement the training procedure and adapt it for stable diffusion. + +By accommodating model data in CPU and GPU and moving the data to the computing device when necessary, [Gemini](https://www.colossalai.org/docs/advanced_tutorials/meet_gemini), the Heterogeneous Memory Manager of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) can breakthrough the GPU memory wall by using GPU and CPU memory (composed of CPU DRAM or nvme SSD memory) together at the same time. Moreover, the model scale can be further improved by combining heterogeneous training with the other parallel approaches, such as data parallel, tensor parallel and pipeline parallel. + +## Installation + +To begin with, make sure your operating system has the cuda version suitable for this exciting training session, which is cuda11.6-11.8. Notice that you may want to make sure the module versions suitable for the whole environment. Before running the scripts, make sure to install the library's training dependencies: + +```bash +pip install -r requirements.txt +``` + +### Install [colossalai](https://github.com/hpcaitech/ColossalAI.git) + +```bash +pip install colossalai +``` + +**From source** + +```bash +git clone https://github.com/hpcaitech/ColossalAI.git +python setup.py install +``` + +## Dataset for Teyvat BLIP captions +Dataset used to train [Teyvat characters text to image model](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion). + +BLIP generated captions for characters images from [genshin-impact fandom wiki](https://genshin-impact.fandom.com/wiki/Character#Playable_Characters)and [biligame wiki for genshin impact](https://wiki.biligame.com/ys/%E8%A7%92%E8%89%B2). + +For each row the dataset contains `image` and `text` keys. `image` is a varying size PIL png, and `text` is the accompanying text caption. Only a train split is provided. + +The `text` include the tag `Teyvat`, `Name`,`Element`, `Weapon`, `Region`, `Model type`, and `Description`, the `Description` is captioned with the [pre-trained BLIP model](https://github.com/salesforce/BLIP). + +## Training + +We provide the script `colossalai.sh` to run the training task with colossalai. Meanwhile, we also provided traditional training process of dreambooth, `dreambooth.sh`, for possible comparation. For instance, the script of training process for [stable-diffusion-v1-4] model can be modified into: + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path-to-instance-images" +export OUTPUT_DIR="path-to-save-model" + +torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of sks dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=400 \ + --placement="cuda" +``` +- `MODEL_NAME` refers to the model you are training. +- `INSTANCE_DIR` refers to personalized path to instance images, you might need to insert information here. +- `OUTPUT_DIR` refers to local path to save the trained model, you might need to find a path with enough space. +- `resolution` refers to the corresponding resolution number of your target model. Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model. +- `placement` refers to the training strategy supported by Colossal AI, default = 'cuda', which refers to loading all the parameters into cuda memory. On the other hand, 'cpu' refers to 'cpu offload' strategy while 'auto' enables 'Gemini', both featured by Colossal AI. + +### Training with prior-preservation loss + +Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data. + +According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time. The general script can be then modified as the following. + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path-to-instance-images" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" + +torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=800 \ + --placement="cuda" +``` + +## Inference + +Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. `--instance_prompt="a photo of sks dog" ` in the above example) in your prompt. + +```python +from diffusers import StableDiffusionPipeline +import torch + +model_id = "path-to-save-model" +pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + +prompt = "A photo of sks dog in a bucket" +image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] + +image.save("dog-bucket.png") +``` + +## Invitation to open-source contribution +Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models! + +You may contact us or participate in the following ways: +1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks! +2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). +3. Join the Colossal-AI community on +[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. +4. Send your official proposal to email contact@hpcaitech.com + +Thanks so much to all of our amazing contributors! diff --git a/examples/images/dreambooth/colossalai.sh b/examples/images/dreambooth/colossalai.sh new file mode 100755 index 0000000000000000000000000000000000000000..227d8b8bdb0410d29ae61b1f51ff4d6022bc7dd0 --- /dev/null +++ b/examples/images/dreambooth/colossalai.sh @@ -0,0 +1,22 @@ +export MODEL_NAME= +export INSTANCE_DIR= +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" + +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 +DIFFUSERS_OFFLINE=1 + +torchrun --nproc_per_node 2 --master_port=25641 train_dreambooth_colossalai.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of a dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --placement="cuda" \ diff --git a/examples/images/dreambooth/debug.py b/examples/images/dreambooth/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..33219b2caa298c626f9b3b50e8fc8645c922b9ca --- /dev/null +++ b/examples/images/dreambooth/debug.py @@ -0,0 +1,21 @@ +''' +torchrun --standalone --nproc_per_node=1 debug.py +''' + +from diffusers import AutoencoderKL + +import colossalai +from colossalai.zero import ColoInitContext, post_process_colo_init_ctx + +path = "/data/scratch/diffuser/stable-diffusion-v1-4" + +colossalai.launch_from_torch(config={}) +with ColoInitContext(device='cpu'): + vae = AutoencoderKL.from_pretrained( + path, + subfolder="vae", + revision=None, + ) + +for n, p in vae.named_parameters(): + print(n) diff --git a/examples/images/dreambooth/dreambooth.sh b/examples/images/dreambooth/dreambooth.sh new file mode 100644 index 0000000000000000000000000000000000000000..e063bc8279c53784a16df0a0f8c76bf785ecf46e --- /dev/null +++ b/examples/images/dreambooth/dreambooth.sh @@ -0,0 +1,12 @@ +python train_dreambooth.py \ + --pretrained_model_name_or_path= ## Your Model Path \ + --instance_data_dir= ## Your Training Input Pics Path \ + --output_dir="path-to-save-model" \ + --instance_prompt="a photo of a dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ diff --git a/examples/images/dreambooth/inference.py b/examples/images/dreambooth/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..c342821c783003b20b4ddb765c21a99f9397331f --- /dev/null +++ b/examples/images/dreambooth/inference.py @@ -0,0 +1,12 @@ +from diffusers import StableDiffusionPipeline, DiffusionPipeline +import torch + +model_id = +print(f"Loading model... from{model_id}") + +pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + +prompt = "A photo of an apple." +image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] + +image.save("output.png") diff --git a/examples/images/dreambooth/requirements.txt b/examples/images/dreambooth/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1ec828c630ef418242991e1e90fefc054e2ed527 --- /dev/null +++ b/examples/images/dreambooth/requirements.txt @@ -0,0 +1,7 @@ +diffusers>==0.5.0 +accelerate +torchvision +transformers>=4.21.0 +ftfy +tensorboard +modelcards diff --git a/examples/images/dreambooth/test_ci.sh b/examples/images/dreambooth/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/dreambooth/train_dreambooth.py b/examples/images/dreambooth/train_dreambooth.py new file mode 100644 index 0000000000000000000000000000000000000000..b989955f7fb70c43a8daa5f2d87fbc0a3a3f0465 --- /dev/null +++ b/examples/images/dreambooth/train_dreambooth.py @@ -0,0 +1,694 @@ +import argparse +import hashlib +import itertools +import math +import os +from pathlib import Path +from typing import Optional + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +logger = get_logger(__name__) + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=("Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution"), + ) + parser.add_argument("--center_crop", + action="store_true", + help="Whether to center crop images before resizing to resolution") + parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") + parser.add_argument("--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.") + parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), + ) + parser.add_argument("--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.") + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + if args.class_data_dir is not None: + logger.warning("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + logger.warning("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose([ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + return example + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with="tensorboard", + logging_dir=logging_dir, + ) + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future.") + + if args.seed is not None: + set_seed(args.seed) + + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm(sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + use_fast=False, + ) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path) + + # Load models and create wrapper for stable diffusion + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + ) + + vae.requires_grad_(False) + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + if args.scale_lr: + args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * + accelerator.num_processes) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder else unet.parameters()) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") + + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + def collate_fn(examples): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if args.with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = tokenizer.pad( + { + "input_ids": input_ids + }, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + train_dataloader = torch.utils.data.DataLoader(train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=1) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, + lr_scheduler) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu. + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + global_step = 0 + + for epoch in range(args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder else unet.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if global_step % args.save_steps == 0: + if accelerator.is_main_process: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + revision=args.revision, + ) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + pipeline.save_pretrained(save_path) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + + # Create the pipeline using using the trained modules and save it. + if accelerator.is_main_process: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + revision=args.revision, + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py new file mode 100644 index 0000000000000000000000000000000000000000..e6159e1058b97f21135e60ffe88e535455e993d3 --- /dev/null +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -0,0 +1,677 @@ +import argparse +import hashlib +import math +import os +from pathlib import Path +from typing import Optional + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from huggingface_hub import HfFolder, Repository, create_repo, whoami +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import colossalai +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer +from colossalai.zero.gemini import get_static_torch_model + +disable_existing_loggers() +logger = get_dist_logger() + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default="a photo of sks dog", + required=False, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=("Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution"), + ) + parser.add_argument( + "--placement", + type=str, + default="cpu", + help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping."), + ) + parser.add_argument("--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.") + parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), + ) + parser.add_argument("--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.") + + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + if args.class_data_dir is not None: + logger.warning("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + logger.warning("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose([ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + return example + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +# Gemini + ZeRO DDP +def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"): + from colossalai.nn.parallel import GeminiDDP + + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=placememt_policy, + pin_memory=True, + search_range_mb=64) + return model + + +def main(args): + if args.seed is None: + colossalai.launch_from_torch(config={}) + else: + colossalai.launch_from_torch(config={}, seed=args.seed) + + local_rank = gpc.get_local_rank(ParallelMode.DATA) + world_size = gpc.get_world_size(ParallelMode.DATA) + + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + pipeline.to(get_current_device()) + + for example in tqdm( + sample_dataloader, + desc="Generating class images", + disable=not local_rank == 0, + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + + # Handle the repository creation + if local_rank == 0: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + logger.info(f"Loading tokenizer from {args.tokenizer_name}", ranks=[0]) + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + use_fast=False, + ) + elif args.pretrained_model_name_or_path: + logger.info("Loading tokenizer from pretrained model", ranks=[0]) + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path) + + # Load models and create wrapper for stable diffusion + + logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0]) + + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + + logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0]) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + ) + + logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) + with ColoInitContext(device=get_current_device()): + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + low_cpu_mem_usage=False) + + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + if args.scale_lr: + args.learning_rate = args.learning_rate * args.train_batch_size * world_size + + unet = gemini_zero_dpp(unet, args.placement) + + # config optimizer for colossalai zero + optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm) + + # load noise_scheduler + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + # prepare dataset + logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0]) + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + def collate_fn(examples): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if args.with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = tokenizer.pad( + { + "input_ids": input_ids + }, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + train_dataloader = torch.utils.data.DataLoader(train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=1) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader)) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + ) + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu. + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + vae.to(get_current_device(), dtype=weight_dtype) + text_encoder.to(get_current_device(), dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader)) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + total_batch_size = args.train_batch_size * world_size + + logger.info("***** Running training *****", ranks=[0]) + logger.info(f" Num examples = {len(train_dataset)}", ranks=[0]) + logger.info(f" Num batches each epoch = {len(train_dataloader)}", ranks=[0]) + logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0]) + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}", ranks=[0]) + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0]) + logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0]) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0) + progress_bar.set_description("Steps") + global_step = 0 + + torch.cuda.synchronize() + for epoch in range(args.num_train_epochs): + unet.train() + for step, batch in enumerate(train_dataloader): + torch.cuda.reset_peak_memory_stats() + # Move batch to gpu + for key, value in batch.items(): + batch[key] = value.to(get_current_device(), non_blocking=True) + + # Convert images to latent space + optimizer.zero_grad() + + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + optimizer.backward(loss) + + optimizer.step() + lr_scheduler.step() + logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0]) + # Checks if the accelerator has performed an optimization step behind the scenes + progress_bar.update(1) + global_step += 1 + logs = { + "loss": loss.detach().item(), + "lr": optimizer.param_groups[0]["lr"], + } # lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step % args.save_steps == 0: + torch.cuda.synchronize() + torch_unet = get_static_torch_model(unet) + if local_rank == 0: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=torch_unet, + revision=args.revision, + ) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + pipeline.save_pretrained(save_path) + logger.info(f"Saving model checkpoint to {save_path}", ranks=[0]) + if global_step >= args.max_train_steps: + break + + torch.cuda.synchronize() + unet = get_static_torch_model(unet) + + if local_rank == 0: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=unet, + revision=args.revision, + ) + + pipeline.save_pretrained(args.output_dir) + logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0]) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..1b2fc778d5ed05de78b59464c9bd231d0d411a21 --- /dev/null +++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py @@ -0,0 +1,690 @@ +import argparse +import hashlib +import math +import os +from pathlib import Path +from typing import Optional + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.loaders import AttnProcsLayers +from diffusers.models.cross_attention import LoRACrossAttnProcessor +from diffusers.optimization import get_scheduler +from huggingface_hub import HfFolder, Repository, create_repo, whoami +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import colossalai +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer +from colossalai.zero.gemini import get_static_torch_model + +disable_existing_loggers() +logger = get_dist_logger() + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default="a photo of sks dog", + required=False, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=("Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution"), + ) + parser.add_argument( + "--placement", + type=str, + default="cpu", + help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping."), + ) + parser.add_argument("--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.") + parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), + ) + parser.add_argument("--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.") + + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + if args.class_data_dir is not None: + logger.warning("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + logger.warning("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose([ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + return example + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +# Gemini + ZeRO DDP +def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"): + from colossalai.nn.parallel import GeminiDDP + + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=placememt_policy, + pin_memory=True, + search_range_mb=64) + return model + + +def main(args): + if args.seed is None: + colossalai.launch_from_torch(config={}) + else: + colossalai.launch_from_torch(config={}, seed=args.seed) + + local_rank = gpc.get_local_rank(ParallelMode.DATA) + world_size = gpc.get_world_size(ParallelMode.DATA) + + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + pipeline.to(get_current_device()) + + for example in tqdm( + sample_dataloader, + desc="Generating class images", + disable=not local_rank == 0, + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + + # Handle the repository creation + if local_rank == 0: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + logger.info(f"Loading tokenizer from {args.tokenizer_name}", ranks=[0]) + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + use_fast=False, + ) + elif args.pretrained_model_name_or_path: + logger.info("Loading tokenizer from pretrained model", ranks=[0]) + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path) + + # Load models and create wrapper for stable diffusion + + logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0]) + + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + + logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0]) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + ) + + logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) + with ColoInitContext(device=get_current_device()): + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + low_cpu_mem_usage=False) + unet.requires_grad_(False) + + # Set correct lora layers + lora_attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim) + + unet.set_attn_processor(lora_attn_procs) + lora_layers = AttnProcsLayers(unet.attn_processors) + + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + if args.scale_lr: + args.learning_rate = args.learning_rate * args.train_batch_size * world_size + + unet = gemini_zero_dpp(unet, args.placement) + + # config optimizer for colossalai zero + optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm) + + # load noise_scheduler + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + # prepare dataset + logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0]) + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + def collate_fn(examples): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if args.with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = tokenizer.pad( + { + "input_ids": input_ids + }, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + train_dataloader = torch.utils.data.DataLoader(train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=1) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader)) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + ) + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu. + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + vae.to(get_current_device(), dtype=weight_dtype) + text_encoder.to(get_current_device(), dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader)) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + total_batch_size = args.train_batch_size * world_size + + logger.info("***** Running training *****", ranks=[0]) + logger.info(f" Num examples = {len(train_dataset)}", ranks=[0]) + logger.info(f" Num batches each epoch = {len(train_dataloader)}", ranks=[0]) + logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0]) + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}", ranks=[0]) + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0]) + logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0]) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0) + progress_bar.set_description("Steps") + global_step = 0 + + torch.cuda.synchronize() + for epoch in range(args.num_train_epochs): + unet.train() + for step, batch in enumerate(train_dataloader): + torch.cuda.reset_peak_memory_stats() + # Move batch to gpu + for key, value in batch.items(): + batch[key] = value.to(get_current_device(), non_blocking=True) + + # Convert images to latent space + optimizer.zero_grad() + + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + optimizer.backward(loss) + + optimizer.step() + lr_scheduler.step() + logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0]) + # Checks if the accelerator has performed an optimization step behind the scenes + progress_bar.update(1) + global_step += 1 + logs = { + "loss": loss.detach().item(), + "lr": optimizer.param_groups[0]["lr"], + } # lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step % args.save_steps == 0: + torch.cuda.synchronize() + torch_unet = get_static_torch_model(unet) + if local_rank == 0: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + torch_unet = torch_unet.to(torch.float32) + torch_unet.save_attn_procs(save_path) + logger.info(f"Saving model checkpoint to {save_path}", ranks=[0]) + if global_step >= args.max_train_steps: + break + + torch.cuda.synchronize() + torch_unet = get_static_torch_model(unet) + + if local_rank == 0: + torch_unet = torch_unet.to(torch.float32) + torch_unet.save_attn_procs(save_path) + logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0]) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/images/dreambooth/train_dreambooth_inpaint.py b/examples/images/dreambooth/train_dreambooth_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..774cd4c458e9c02548a1165376857eb8398c4311 --- /dev/null +++ b/examples/images/dreambooth/train_dreambooth_inpaint.py @@ -0,0 +1,720 @@ +import argparse +import hashlib +import itertools +import math +import os +import random +from pathlib import Path +from typing import Optional + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image, ImageDraw +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +logger = get_logger(__name__) + + +def prepare_mask_and_masked_image(image, mask): + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + return mask, masked_image + + +# generate random masks +def random_mask(im_shape, ratio=1, mask_full_image=False): + mask = Image.new("L", im_shape, 0) + draw = ImageDraw.Draw(mask) + size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio))) + # use this to always mask the whole image + if mask_full_image: + size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio)) + limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2) + center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1])) + draw_type = random.randint(0, 1) + if draw_type == 0 or mask_full_image: + draw.rectangle( + (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2), + fill=255, + ) + else: + draw.ellipse( + (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2), + fill=255, + ) + + return mask + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=("Minimal class images for prior preservation loss. If not have enough images, additional images will be" + " sampled with class_prompt."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution"), + ) + parser.add_argument("--center_crop", + action="store_true", + help="Whether to center crop images before resizing to resolution") + parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") + parser.add_argument("--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.") + parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), + ) + parser.add_argument("--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.") + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=("Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU."), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.instance_data_dir is None: + raise ValueError("You must specify a train data directory.") + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose([ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + + example["PIL_images"] = instance_image + example["instance_images"] = self.image_transforms(instance_image) + + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_PIL_images"] = class_image + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + return example + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def main(): + args = parse_args() + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with="tensorboard", + logging_dir=logging_dir, + ) + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future.") + + if args.seed is not None: + set_seed(args.seed) + + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + pipeline = StableDiffusionInpaintPipeline.from_pretrained(args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, + batch_size=args.sample_batch_size, + num_workers=1) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + transform_to_pil = transforms.ToPILImage() + for example in tqdm(sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process): + bsz = len(example["prompt"]) + fake_images = torch.rand((3, args.resolution, args.resolution)) + transform_to_pil = transforms.ToPILImage() + fake_pil_images = transform_to_pil(fake_images) + + fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True) + + images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + + # Load models and create wrapper for stable diffusion + text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + + vae.requires_grad_(False) + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + if args.scale_lr: + args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * + accelerator.num_processes) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder else unet.parameters()) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") + + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + def collate_fn(examples): + image_transforms = transforms.Compose([ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + ]) + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if args.with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + pior_pil = [example["class_PIL_images"] for example in examples] + + masks = [] + masked_images = [] + for example in examples: + pil_image = example["PIL_images"] + # generate a random mask + mask = random_mask(pil_image.size, 1, False) + # apply transforms + mask = image_transforms(mask) + pil_image = image_transforms(pil_image) + # prepare mask and masked image + mask, masked_image = prepare_mask_and_masked_image(pil_image, mask) + + masks.append(mask) + masked_images.append(masked_image) + + if args.with_prior_preservation: + for pil_image in pior_pil: + # generate a random mask + mask = random_mask(pil_image.size, 1, False) + # apply transforms + mask = image_transforms(mask) + pil_image = image_transforms(pil_image) + # prepare mask and masked image + mask, masked_image = prepare_mask_and_masked_image(pil_image, mask) + + masks.append(mask) + masked_images.append(masked_image) + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids + masks = torch.stack(masks) + masked_images = torch.stack(masked_images) + batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images} + return batch + + train_dataloader = torch.utils.data.DataLoader(train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, + lr_scheduler) + + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu. + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + global_step = 0 + + for epoch in range(args.num_train_epochs): + unet.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # Convert images to latent space + + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + + # Convert masked images to latent space + masked_latents = vae.encode(batch["masked_images"].reshape( + batch["pixel_values"].shape).to(dtype=weight_dtype)).latent_dist.sample() + masked_latents = masked_latents * 0.18215 + + masks = batch["masks"] + # resize the mask to latents shape as we concatenate the mask to the latents + mask = torch.stack([ + torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) + for mask in masks + ]) + mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # concatenate the noised latents with the mask and the masked latents + latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. + noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder else unet.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + + # Create the pipeline using using the trained modules and save it. + if accelerator.is_main_process: + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/examples/images/vit/README.md b/examples/images/vit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4423d85d19e0549a04138097aa7526d8f13d71a4 --- /dev/null +++ b/examples/images/vit/README.md @@ -0,0 +1,61 @@ +# Vision Transformer with ColoTensor + +# Overview + +In this example, we will run Vision Transformer with ColoTensor. + +We use model **ViTForImageClassification** from Hugging Face [Link](https://huggingface.co/docs/transformers/model_doc/vit) for unit test. +You can change world size or decide whether use DDP in our code. + +We use model **vision_transformer** from timm [Link](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) for training example. + +(2022/6/28) The default configuration now supports 2DP+2TP with gradient accumulation and checkpoint support. Zero is not supported at present. + +# Requirement + +Install colossalai version >= 0.1.11 + +## Unit test +To run unit test, you should install pytest, transformers with: +```shell +pip install pytest transformers +``` + +## Training example +To run training example with ViT-S, you should install **NVIDIA DALI** from [Link](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html) for dataloader support. +You also need to install timm and titans for model/dataloader support with: +```shell +pip install timm titans +``` + +### Data preparation +You can download the ImageNet dataset from the [ImageNet official website](https://www.image-net.org/download.php). You should get the raw images after downloading the dataset. As we use **NVIDIA DALI** to read data, we use the TFRecords dataset instead of raw Imagenet dataset. This offers better speedup to IO. If you don't have TFRecords dataset, follow [imagenet-tools](https://github.com/ver217/imagenet-tools) to build one. + +Before you start training, you need to set the environment variable `DATA` so that the script knows where to fetch the data for DALI dataloader. +```shell +export DATA=/path/to/ILSVRC2012 +``` + + +# How to run + +## Unit test +In your terminal +```shell +pytest test_vit.py +``` + +This will evaluate models with different **world_size** and **use_ddp**. + +## Training example +Modify the settings in run.sh according to your environment. +For example, if you set `--nproc_per_node=8` in `run.sh` and `TP_WORLD_SIZE=2` in your config file, +data parallel size will be automatically calculated as 4. +Thus, the parallel strategy is set to 4DP+2TP. + +Then in your terminal +```shell +sh run.sh +``` + +This will start ViT-S training with ImageNet. diff --git a/examples/images/vit/configs/vit_1d_tp2.py b/examples/images/vit/configs/vit_1d_tp2.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf399f2e50daaa70f52726ccfa6f4e035ce7380 --- /dev/null +++ b/examples/images/vit/configs/vit_1d_tp2.py @@ -0,0 +1,32 @@ +from colossalai.amp import AMP_TYPE + +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 256 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 300 +WARMUP_EPOCHS = 32 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 384 +DEPTH = 12 +NUM_HEADS = 6 +MLP_RATIO = 4 +NUM_CLASSES = 1000 +CHECKPOINT = False +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token + +USE_DDP = True +TP_WORLD_SIZE = 2 +TP_TYPE = 'row' +parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 +gradient_accumulation = 8 + +LOG_PATH = "./log" diff --git a/examples/images/vit/configs/vit_1d_tp2_ci.py b/examples/images/vit/configs/vit_1d_tp2_ci.py new file mode 100644 index 0000000000000000000000000000000000000000..e491e4ada45e25f57c8e6c93df41e36794f0b420 --- /dev/null +++ b/examples/images/vit/configs/vit_1d_tp2_ci.py @@ -0,0 +1,32 @@ +from colossalai.amp import AMP_TYPE + +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 8 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 3 +WARMUP_EPOCHS = 1 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 32 +DEPTH = 2 +NUM_HEADS = 4 +MLP_RATIO = 4 +NUM_CLASSES = 10 +CHECKPOINT = False +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token + +USE_DDP = True +TP_WORLD_SIZE = 2 +TP_TYPE = 'row' +parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 +gradient_accumulation = 2 + +LOG_PATH = "./log_ci" diff --git a/examples/images/vit/requirements.txt b/examples/images/vit/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1f69794ebe700eda3b4f37e17d9253978a938125 --- /dev/null +++ b/examples/images/vit/requirements.txt @@ -0,0 +1,8 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 +numpy>=1.24.1 +timm>=0.6.12 +titans>=0.0.7 +tqdm>=4.61.2 +transformers>=4.25.1 +nvidia-dali-cuda110>=1.8.0 --extra-index-url https://developer.download.nvidia.com/compute/redist diff --git a/examples/images/vit/run.sh b/examples/images/vit/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..84fe58f11a6a7c25242eab1714fd94698509be9a --- /dev/null +++ b/examples/images/vit/run.sh @@ -0,0 +1,15 @@ +export DATA=/data/scratch/imagenet/tf_records +export OMP_NUM_THREADS=4 + +# resume +# CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \ +# --nproc_per_node 4 train.py \ +# --config configs/vit_1d_tp2.py \ +# --resume_from checkpoint/epoch_10 \ +# --master_port 29598 | tee ./out 2>&1 + +# train +CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \ +--nproc_per_node 4 train.py \ +--config configs/vit_1d_tp2.py \ +--master_port 29598 | tee ./out 2>&1 diff --git a/examples/images/vit/test_ci.sh b/examples/images/vit/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..41d25ee23521d85040efa510a6b792344c0a62bb --- /dev/null +++ b/examples/images/vit/test_ci.sh @@ -0,0 +1,9 @@ +export OMP_NUM_THREADS=4 + +pip install -r requirements.txt + +# train +colossalai run \ +--nproc_per_node 4 train.py \ +--config configs/vit_1d_tp2_ci.py \ +--dummy_data diff --git a/examples/images/vit/test_vit.py b/examples/images/vit/test_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..c0ae35bca87169c2813a3c796dd84064202742ac --- /dev/null +++ b/examples/images/vit/test_vit.py @@ -0,0 +1,160 @@ +import os +import random + +import numpy as np +import pytest +import torch +from torch.nn.parallel import DistributedDataParallel as DDP +from vit import get_training_components + +import colossalai +from colossalai.context import ParallelMode +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.parallel.data_parallel import ColoDDP +from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext + + +def set_seed(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +def tensor_equal(A, B): + return torch.allclose(A, B, rtol=1e-3, atol=1e-1) + + +def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor): + assert tensor.ndim == shard.ndim + if tensor.shape == shard.shape: + return tensor_equal(tensor, shard) + else: + dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape)) + if dims_not_eq.numel() == 1: + # 1D shard + dim = dims_not_eq.item() + world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + return tensor_equal(tensor.chunk(world_size, dim)[rank], shard) + else: + raise + + +# Only for all Linear, it's 1d_row split because Linear will be transposed when calculating. +# But for other layers, it's 1d_col split. +# Layernorm is not supported for now. +# patch_embeddings.projection has nn.Conv2d +# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182 +def init_1d_row_for_linear_weight_spec(model, world_size: int): + pg = ProcessGroup(tp_degree=world_size) + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if 'weight' in n and 'layernorm' not in n and 'embeddings.patch_embeddings.projection.weight' not in n: + p.set_process_group(pg) + p.set_tensor_spec(*spec) + + +# Similarly, it's col split for Linear but row split for others. +def init_1d_col_for_linear_weight_bias_spec(model, world_size: int): + pg = ProcessGroup(tp_degree=world_size) + spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if ('weight' in n + or 'bias' in n) and 'layernorm' not in n and 'embeddings.patch_embeddings.projection' not in n: + p.set_process_group(pg) + p.set_tensor_spec(*spec) + + +def check_param_equal(model, torch_model): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + assert tensor_shard_equal(torch_p, p) + + +def check_grad_equal(model, torch_model): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + if (torch_p.grad.shape == p.grad.shape): + assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0) == True + else: + dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape)) + dim = dims_not_eq.item() + world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + assert torch.allclose(torch_p.grad.chunk(world_size, dim)[rank], p.grad, rtol=1e-3, atol=2.0) == True + + +def run_vit(init_spec_func, use_ddp): + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_training_components() + with ColoInitContext(device=get_current_device()): + model = model_builder() + model = model.cuda() + torch_model = model_builder().cuda() + if use_ddp: + model = ColoDDP(model) + torch_model = DDP(torch_model, + device_ids=[gpc.get_global_rank()], + process_group=gpc.get_group(ParallelMode.DATA)) + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p) + + world_size = torch.distributed.get_world_size() + init_spec_func(model, world_size) + + check_param_equal(model, torch_model) + model.train() + torch_model.train() + set_seed(gpc.get_local_rank(ParallelMode.DATA)) + + optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + + for i, image_dict in enumerate(train_dataloader): + if use_ddp: + model.zero_grad() + else: + optimizer.zero_grad() + logits = model(image_dict['pixel_values']) + torch_logits = torch_model(image_dict['pixel_values']) + assert tensor_equal(torch_logits.logits, logits.logits) + loss = criterion(logits.logits, image_dict['label']) + torch_loss = criterion(torch_logits.logits, image_dict['label']) + if use_ddp: + model.backward(loss) + else: + loss.backward() + torch_loss.backward() + check_grad_equal(model, torch_model) + optimizer.step() + torch_optimizer.step() + check_param_equal(model, torch_model) + break + + +def run_dist(rank, world_size, port, use_ddp): + if use_ddp and world_size == 1: + return + tp_world_size = world_size // 2 if use_ddp else world_size + config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_vit(init_1d_row_for_linear_weight_spec, use_ddp) + run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize('use_ddp', [False, True]) +@rerun_if_address_is_in_use() +def test_vit(world_size, use_ddp): + spawn(run_dist, world_size, use_ddp=use_ddp) + + +if __name__ == '__main__': + test_vit(1, False) diff --git a/examples/images/vit/train.py b/examples/images/vit/train.py new file mode 100644 index 0000000000000000000000000000000000000000..b42cf2bedc6bf738c17230c9fa40fcfc3418ba99 --- /dev/null +++ b/examples/images/vit/train.py @@ -0,0 +1,174 @@ +import os + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from timm.models.vision_transformer import _create_vision_transformer +from titans.dataloader.imagenet import build_dali_imagenet +from tqdm import tqdm +from vit import DummyDataLoader + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn._ops import * +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.parallel.data_parallel import ColoDDP +from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext + + +def init_1d_row_for_linear_weight_spec(model, world_size: int): + pg = ProcessGroup(tp_degree=world_size) + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if 'weight' in n and 'norm' not in n and 'patch_embed.proj.weight' not in n: + p.set_process_group(pg) + p.set_tensor_spec(*spec) + + +# Similarly, it's col split for Linear but row split for others. +def init_1d_col_for_linear_weight_bias_spec(model, world_size: int): + pg = ProcessGroup(tp_degree=world_size) + spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if ('weight' in n or 'bias' in n) and 'norm' not in n and ('patch_embed.proj.weight' not in n + and 'patch_embed.proj.bias' not in n): + p.set_process_group(pg) + p.set_tensor_spec(*spec) + + +def init_spec_func(model, tp_type): + world_size = torch.distributed.get_world_size() + if tp_type == 'row': + init_1d_row_for_linear_weight_spec(model, world_size) + elif tp_type == 'col': + init_1d_col_for_linear_weight_bias_spec(model, world_size) + else: + raise NotImplemented + + +def train_imagenet(): + + parser = colossalai.get_default_parser() + parser.add_argument('--resume_from', default=False, action='store_true') + parser.add_argument('--dummy_data', default=False, action='store_true') + + args = parser.parse_args() + colossalai.launch_from_torch(config=args.config) + use_ddp = gpc.config.USE_DDP + + disable_existing_loggers() + + logger = get_dist_logger() + if hasattr(gpc.config, 'LOG_PATH'): + if gpc.get_global_rank() == 0: + log_path = gpc.config.LOG_PATH + if not os.path.exists(log_path): + os.mkdir(log_path) + logger.log_to_file(log_path) + + logger.info('Build data loader', ranks=[0]) + if not args.dummy_data: + root = os.environ['DATA'] + train_dataloader, test_dataloader = build_dali_imagenet(root, + train_batch_size=gpc.config.BATCH_SIZE, + test_batch_size=gpc.config.BATCH_SIZE) + else: + train_dataloader = DummyDataLoader(length=10, + batch_size=gpc.config.BATCH_SIZE, + category=gpc.config.NUM_CLASSES, + image_size=gpc.config.IMG_SIZE, + return_dict=False) + test_dataloader = DummyDataLoader(length=5, + batch_size=gpc.config.BATCH_SIZE, + category=gpc.config.NUM_CLASSES, + image_size=gpc.config.IMG_SIZE, + return_dict=False) + + logger.info('Build model', ranks=[0]) + + model_kwargs = dict(img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + embed_dim=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=gpc.config.NUM_CLASSES, + drop_rate=0.1, + attn_drop_rate=0.1, + weight_init='jax') + + with ColoInitContext(device=get_current_device()): + model = _create_vision_transformer('vit_small_patch16_224', pretrained=False, **model_kwargs) + init_spec_func(model, gpc.config.TP_TYPE) + + world_size = torch.distributed.get_world_size() + model = ColoDDP(module=model, process_group=ProcessGroup(tp_degree=world_size)) + logger.info('Build criterion, optimizer, lr_scheduler', ranks=[0]) + optimizer = HybridAdam(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) + + criterion = CrossEntropyLoss() + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS, + warmup_steps=gpc.config.WARMUP_EPOCHS) + + start_epoch = 0 + if args.resume_from: + load_model = torch.load(args.resume_from + '_model.pth') + start_epoch = load_model['epoch'] + model.load_state_dict(load_model['model']) + load_optim = torch.load(args.resume_from + '_optim_rank_{}.pth'.format(dist.get_rank())) + optimizer.load_state_dict(load_optim['optim']) + + for epoch in range(start_epoch, gpc.config.NUM_EPOCHS): + model.train() + for index, (x, y) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False): + x, y = x.cuda(), y.cuda() + output = model(x) + loss = criterion(output, y) + loss = loss / gpc.config.gradient_accumulation + if use_ddp: + model.backward(loss) + else: + loss.backward() + if (index + 1) % gpc.config.gradient_accumulation == 0: + optimizer.step() + if use_ddp: + model.zero_grad() + else: + optimizer.zero_grad() + + logger.info( + f"Finish Train Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {loss.item():.3f} lr: {optimizer.state_dict()['param_groups'][0]['lr']}", + ranks=[0]) + + model.eval() + test_loss = 0 + correct = 0 + test_sum = 0 + with torch.no_grad(): + for index, (x, y) in tqdm(enumerate(test_dataloader), total=len(test_dataloader), leave=False): + x, y = x.cuda(), y.cuda() + output = model(x) + test_loss += F.cross_entropy(output, y, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(y.view_as(pred)).sum().item() + test_sum += y.size(0) + + test_loss /= test_sum + logger.info( + f"Finish Test Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {test_loss:.3f} Accuracy: [{correct}/{test_sum}]({correct/test_sum:.3f})", + ranks=[0]) + + lr_scheduler.step() + + +if __name__ == '__main__': + train_imagenet() diff --git a/examples/images/vit/vit.py b/examples/images/vit/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..f22e8ea90cecca131c4314d7eb7fb02c74cb371c --- /dev/null +++ b/examples/images/vit/vit.py @@ -0,0 +1,95 @@ +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn +from transformers import ViTConfig, ViTForImageClassification + +from colossalai.utils.cuda import get_current_device + + +class DummyDataGenerator(ABC): + + def __init__(self, length=10): + self.length = length + + @abstractmethod + def generate(self): + pass + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.generate() + else: + raise StopIteration + + def __len__(self): + return self.length + + +class DummyDataLoader(DummyDataGenerator): + + def __init__(self, length=10, batch_size=4, channel=3, category=8, image_size=224, return_dict=True): + super().__init__(length) + self.batch_size = batch_size + self.channel = channel + self.category = category + self.image_size = image_size + self.return_dict = return_dict + + def generate(self): + image_dict = {} + image_dict['pixel_values'] = torch.rand( + self.batch_size, self.channel, self.image_size, self.image_size, device=get_current_device()) * 2 - 1 + image_dict['label'] = torch.randint(self.category, (self.batch_size,), + dtype=torch.int64, + device=get_current_device()) + if not self.return_dict: + return image_dict['pixel_values'], image_dict['label'] + return image_dict + + +class ViTCVModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + image_size=224, + patch_size=16, + num_channels=3, + num_labels=8, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = ViTForImageClassification( + ViTConfig(hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + image_size=image_size, + patch_size=patch_size, + num_channels=num_channels, + num_labels=num_labels)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, pixel_values): + return self.model(pixel_values=pixel_values) + + +def vit_base_s(checkpoint=True): + return ViTCVModel(checkpoint=checkpoint) + + +def vit_base_micro(checkpoint=True): + return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint) + + +def get_training_components(): + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy diff --git a/examples/language/bert/run_gemini.sh b/examples/language/bert/run_gemini.sh new file mode 100644 index 0000000000000000000000000000000000000000..d791334e8c97312f343a0cea6ac70d9ba4f7d3fe --- /dev/null +++ b/examples/language/bert/run_gemini.sh @@ -0,0 +1,22 @@ +set -x +# distplan in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"] +export DISTPLAN=${DISTPLAN:-"CAI_Gemini"} + +# The following options only valid when DISTPLAN="colossalai" +export GPUNUM=${GPUNUM:-1} +export PLACEMENT=${PLACEMENT:-"cpu"} +export BATCH_SIZE=${BATCH_SIZE:-16} + +# bert | albert +export MODEL_TYPE=${MODEL_TYPE:-"bert"} +export TRAIN_STEP=${TRAIN_STEP:-10} + +mkdir -p gemini_logs + +env CUDA_LAUNCH_BLOCKING=1 torchrun --standalone --nproc_per_node=${GPUNUM} ./train_bert_demo.py \ +--model_type=${MODEL_TYPE} \ +--batch_size=${BATCH_SIZE} \ +--placement=${PLACEMENT} \ +--distplan=${DISTPLAN} \ +--train_step=${TRAIN_STEP} \ +2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_${PLACEMENT}.log diff --git a/examples/language/bert/test_ci.sh b/examples/language/bert/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..42c63fec50c0e3729704a4944f22c19116e76494 --- /dev/null +++ b/examples/language/bert/test_ci.sh @@ -0,0 +1,2 @@ +set -x +env GPUNUM=1 bash run_gemini.sh diff --git a/examples/language/bert/train_bert_demo.py b/examples/language/bert/train_bert_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..9a0278b2c711d441405f3f9324fad1bdde66846c --- /dev/null +++ b/examples/language/bert/train_bert_demo.py @@ -0,0 +1,331 @@ +import os +from functools import partial +from time import time + +import psutil +import torch +from packaging import version +from torch import nn +from torch.nn.parallel import DistributedDataParallel as DDP +from transformers import AlbertConfig, AlbertForSequenceClassification, BertConfig, BertForSequenceClassification + +import colossalai +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper + +CAI_VERSION = colossalai.__version__ + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + + +def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): + from contextlib import nullcontext + + from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler + if enable_flag: + return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), + on_trace_ready=tensorboard_trace_handler(save_dir), + record_shapes=True, + profile_memory=True) + else: + + class DummyProfiler: + + def __init__(self): + self.step_number = 0 + + def step(self): + self.step_number += 1 + + return nullcontext(DummyProfiler()) + + +def get_time_stamp(): + import time + cur_time = time.strftime("%d-%H:%M", time.localtime()) + return cur_time + + +def get_bert_data(batch_size: int, sequence_length: int, vacob_size: int, n_class: int, device: torch.device): + input = torch.randint( + low=0, + high=vacob_size, + size=(batch_size, sequence_length), + device=device, + dtype=torch.long, + ) + label = torch.randint(low=0, high=n_class, size=(batch_size,), device=device, dtype=torch.long) + return input, label + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument( + "--distplan", + type=str, + default='CAI_Gemini', + help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", + ) + parser.add_argument( + "--placement", + type=str, + default='cpu', + help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=8, + help="batch size per DP group of training.", + ) + parser.add_argument( + "--model_type", + type=str, + default="bert", + help="bert or albert", + ) + parser.add_argument( + "--train_step", + type=int, + default=10, + help="training iterations for test", + ) + + args = parser.parse_args() + return args + + +SEQ_LEN = 512 +VOCAB_SIZE = 1000 +NUM_LABELS = 10 + + +# Parameter Sharding Strategies for Tensor Parallelism +def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(-1, param, pg) + + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def get_model_size(model: nn.Module): + total_numel = 0 + for module in model.modules(): + for p in module.parameters(recurse=False): + total_numel += p.numel() + return total_numel + + +def model_builder(args): + if args.model_type == "bert": + cfg = BertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS) + return BertForSequenceClassification(cfg) + elif args.model_type == "albert": + cfg = AlbertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS) + return AlbertForSequenceClassification(cfg) + else: + raise RuntimeError + + +def model_size_formatter(numel: int) -> str: + GB_SIZE = 10**9 + MB_SIZE = 10**6 + KB_SIZE = 10**3 + if numel >= GB_SIZE: + return f'{numel / GB_SIZE:.1f}B' + elif numel >= MB_SIZE: + return f'{numel / MB_SIZE:.1f}M' + elif numel >= KB_SIZE: + return f'{numel / KB_SIZE:.1f}K' + else: + return str(numel) + + +def set_cpu_maximum_parallelism(): + conf_str = torch.__config__.parallel_info() + inter_str = conf_str.split("hardware_concurrency() : ")[1] + max_concurrency = inter_str.split('\n')[0] + os.environ["OMP_NUM_THREADS"] = max_concurrency + print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.") + + +def main(): + # version check + # this example is supposed to work for versions greater than 0.2.0 + assert version.parse(CAI_VERSION) >= version.parse("0.2.0") + + set_cpu_maximum_parallelism() + args = parse_args() + + # if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]: + if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]: + raise TypeError(f"{args.distplan} is error") + + # batch size per DP degree + BATCH_SIZE = args.batch_size + + NUM_STEPS = args.train_step + + WARMUP_STEPS = 1 + assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps" + assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median" + PROF_FLAG = False # The flag of profiling, False by default + + disable_existing_loggers() + colossalai.launch_from_torch(config={}) + + logger = get_dist_logger() + logger.info(f" {args.distplan}, batch size {BATCH_SIZE}", ranks=[0]) + + torch.manual_seed(123) + if args.distplan.startswith("CAI"): + # all param must use the same process group. + world_size = torch.distributed.get_world_size() + + # build a base-bert model + with ColoInitContext(device=get_current_device(), dtype=torch.half): + model = model_builder(args) + # model = BertForSequenceClassification(BertConfig(vocal_size = VOCAB_SIZE)) + + # asign running configurations + gemini_config = None + if args.distplan.startswith("CAI_ZeRO"): + optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) + elif args.distplan == "CAI_Gemini": + gemini_config = dict(strict_ddp_mode=True, + device=get_current_device(), + placement_policy=args.placement, + pin_memory=True, + hidden_dim=model.config.hidden_size, + search_range_mb=128) + optim_config = dict(gpu_margin_mem_ratio=0.) + else: + raise RuntimeError + + # build a highly optimized gpu/cpu optimizer + optimizer = HybridAdam(model.parameters(), lr=1e-3) + + if args.distplan == "CAI_ZeRO1": + zero_stage = 1 + elif args.distplan == "CAI_ZeRO2": + zero_stage = 2 + elif args.distplan == "CAI_Gemini": + zero_stage = 3 + else: + raise RuntimeError + + # wrap your model and optimizer + model = zero_model_wrapper(model, zero_stage, gemini_config) + optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config) + + logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) + elif args.distplan.startswith("Pytorch"): + model = model_builder(args).cuda() + model = DDP(model) + if args.distplan.endswith("DDP"): + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + elif args.distplan.endswith("ZeRO"): + from torch.distributed.optim import ZeroRedundancyOptimizer + optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3) + else: + raise RuntimeError + + # model is shared after TP + numel = get_model_size(model) + logger.info(f"the size of testing model size is {model_size_formatter(numel)}.") + logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + + # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu + # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree) + # = batch_per_DP_group * numel * seq_len * 8 + get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN) + + torch.cuda.synchronize() + model.train() + tflops_list = [] + + def train_step(): + # we just use randomly generated data here + input_ids, labels = get_bert_data(BATCH_SIZE, + SEQ_LEN, + VOCAB_SIZE, + NUM_LABELS, + device=torch.cuda.current_device()) + optimizer.zero_grad() + + start = time() + outputs = model(input_ids, labels=labels) + loss, logits = outputs[:2] + torch.cuda.synchronize() + fwd_end = time() + fwd_time = fwd_end - start + logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0]) + + if args.distplan.startswith("CAI"): + optimizer.backward(loss) + elif args.distplan.startswith("Pytorch"): + loss.backward() + else: + raise RuntimeError + + torch.cuda.synchronize() + bwd_end = time() + bwd_time = bwd_end - fwd_end + logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0]) + + optimizer.step() + torch.cuda.synchronize() + optim_time = time() - bwd_end + step_time = time() - start + logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0]) + + step_tflops = get_tflops_func(step_time) + logger.info( + f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s", + ranks=[0], + ) + if n >= WARMUP_STEPS: + tflops_list.append(step_tflops) + + demo_profiler = get_profile_context(PROF_FLAG, + WARMUP_STEPS, + NUM_STEPS - WARMUP_STEPS, + save_dir=f"profile/{get_time_stamp()}-demo") + + with demo_profiler as prof: + for n in range(NUM_STEPS): + train_step() + prof.step() + + tflops_list.sort() + median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS + logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}") + torch.cuda.synchronize() + + +if __name__ == '__main__': + main() diff --git a/examples/language/commons/utils.py b/examples/language/commons/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..782f546dc26cae5e99784bc0d8b269b431a01c2b --- /dev/null +++ b/examples/language/commons/utils.py @@ -0,0 +1,12 @@ +import torch + + +# Randomly Generated Data +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) diff --git a/examples/language/gpt/README.md b/examples/language/gpt/README.md new file mode 100644 index 0000000000000000000000000000000000000000..47d24a4d69cb6c2edc4041e5f7501cb2e1e3362c --- /dev/null +++ b/examples/language/gpt/README.md @@ -0,0 +1,80 @@ +# Train GPT with Colossal-AI + +This example shows how to use [Colossal-AI](https://github.com/hpcaitech/ColossalAI) to run huggingface GPT training in distributed manners. + +## GPT + +We use the [GPT-2](https://huggingface.co/gpt2) model from huggingface transformers. The key learning goal of GPT-2 is to use unsupervised pre-training models to do supervised tasks.GPT-2 has an amazing performance in text generation, and the generated text exceeds people's expectations in terms of contextual coherence and emotional expression. + +## Requirements + +Before you can launch training, you need to install the following requirements. + +### Install PyTorch + +```bash +#conda +conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch +#pip +pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 +``` + +### [Install Colossal-AI](https://github.com/hpcaitech/ColossalAI#installation) + + +### Install requirements + +```bash +pip install -r requirements.txt +``` + +This is just an example that we download PyTorch=1.12.0, CUDA=11.6 and colossalai. You can download another version of PyTorch and its corresponding ColossalAI version. Just make sure that the version of ColossalAI is at least 0.1.10, PyTorch is at least 1.8.1 and transformers is at least 4.231. +If you want to test ZeRO1 and ZeRO2 in Colossal-AI, you need to ensure Colossal-AI>=0.1.12. + +## Dataset + +For simplicity, the input data is randomly generated here. + +## Training +We provide two stable solutions. +One utilizes the Gemini to implement hybrid parallel strategies of Gemini, DDP/ZeRO, and Tensor Parallelism for a huggingface GPT model. +The other one use [Titans](https://github.com/hpcaitech/Titans), a distributed executed model zoo maintained by ColossalAI,to implement the hybrid parallel strategies of TP + ZeRO + PP. + +We recommend using Gemini to quickly run your model in a distributed manner. +It doesn't require significant changes to the model structures, therefore you can apply it on a new model easily. +And use Titans as an advanced weapon to pursue a more extreme performance. +Titans has included the some typical models, such as Vit and GPT. +However, it requires some efforts to start if facing a new model structure. + +### GeminiDPP/ZeRO + Tensor Parallelism +```bash +bash run_gemini.sh +``` + +The `train_gpt_demo.py` provides three distributed plans (except ones already provided by PyTorch), you can choose the plan you want in `run_gemini.sh`. The CAI_Gemini leverages Tensor Parallel and Gemini + ZeRO DDP. For their differences, you may check out the answer to issue [here](https://github.com/hpcaitech/ColossalAI/issues/2590#issuecomment-1418766581). + +- ZeRO1 (CAI_ZeRO1) +- ZeRO2 (CAI_ZeRO2) +- Gemini + ZeRO DDP (CAI_Gemini) +- Pytorch DDP (Pytorch_DDP) +- Pytorch ZeRO (Pytorch_ZeRO) + +### Titans (Tensor Parallelism) + ZeRO + Pipeline Parallelism + +Titans provides a customized GPT model, which uses distributed operators as building blocks. +In [./titans/README.md], we provide a hybrid parallelism of ZeRO, TP and PP. +You can switch parallel strategies using a config file. + +## Performance + +Testbed: a cluster of 8xA100 (80GB) and 1xAMD EPYC 7543 32-Core Processor (512 GB). GPUs are connected via PCI-e. +ColossalAI version 0.1.13. + +[benchmark results on google doc](https://docs.google.com/spreadsheets/d/15A2j3RwyHh-UobAPv_hJgT4W_d7CnlPm5Fp4yEzH5K4/edit#gid=0) + +[benchmark results on Tencent doc (for china)](https://docs.qq.com/sheet/DUVpqeVdxS3RKRldk?tab=BB08J2) + +### Experimental Features + +#### [Pipeline Parallel](./experiments/pipeline_parallel/) +#### [Auto Parallel](./experiments/auto_parallel_with_gpt/) diff --git a/examples/language/gpt/experiments/auto_offload/README.md b/examples/language/gpt/experiments/auto_offload/README.md new file mode 100644 index 0000000000000000000000000000000000000000..535aa76541cc2ce97c6a1b3ba59177efcae3bda7 --- /dev/null +++ b/examples/language/gpt/experiments/auto_offload/README.md @@ -0,0 +1,37 @@ +# Auto-Offload Demo with GPT2 + +## Requirements + +Before you can launch training, you need to install the following requirements. + +### Install PyTorch + +```bash +#conda +conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch +#pip +pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 +``` + +### Install [Colossal-AI v0.2.0](https://colossalai.org/download/) From Official Website + +```bash +pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org +``` + +### Install transformers + +```bash +pip install transformers +``` + +## Dataset + +For simplicity, the input data is randomly generated here. + +## Training + +```bash +#Run the auto offload on GPT with default setting and a dummy dataset. +bash run.sh +``` diff --git a/examples/language/gpt/experiments/auto_offload/model_zoo.py b/examples/language/gpt/experiments/auto_offload/model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..35e44608f8108e88ea3d33b8a1f35417cffea9ce --- /dev/null +++ b/examples/language/gpt/experiments/auto_offload/model_zoo.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +from transformers import GPT2Config, GPT2LMHeadModel + +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257): + super().__init__() + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size)) + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0] + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + +def get_gpt2_components(model_type: str, batch_size: int): + vocab_size = 1024 + seq_len = 8 + + def gpt2_model_builder(): + if model_type == "gpt2_medium": + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16) + elif model_type == "gpt2_xl": + return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32) + elif model_type == "gpt2_10b": + return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16) + elif model_type == "gpt2_14b": + return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16) + elif model_type == "gpt2_20b": + return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16) + elif model_type == "gpt2_24b": + return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16) + else: + raise TypeError(f"model_builder {model_type}") + + def gpt2_data_gen(device="cuda"): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + attention_mask = torch.ones_like(input_ids, device=device) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + return kwargs + + return gpt2_model_builder, gpt2_data_gen \ No newline at end of file diff --git a/examples/language/gpt/experiments/auto_offload/requirements.txt b/examples/language/gpt/experiments/auto_offload/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3ebde8d460aad354648666ab18c8413a213047b3 --- /dev/null +++ b/examples/language/gpt/experiments/auto_offload/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 \ No newline at end of file diff --git a/examples/language/gpt/experiments/auto_offload/run.sh b/examples/language/gpt/experiments/auto_offload/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..6a272ec442abbb6d4b83e9a6c97c781dad8840db --- /dev/null +++ b/examples/language/gpt/experiments/auto_offload/run.sh @@ -0,0 +1,8 @@ +export BATCH_SIZE=${BATCH_SIZE:-64} +export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} +export MEMORY_BUDGET=${MEMORY_BUDGET:-16} +export SOLVER_TYPE=${SOLVER_TYPE:-"asyn"} + +mkdir -p offload_logs + +python train_gpt_offload.py --model_type=${MODEL_TYPE} --memory_budget=${MEMORY_BUDGET} --solver_type=${SOLVER_TYPE} --batch_size=${BATCH_SIZE} 2>&1 | tee ./offload_logs/${MODEL_TYPE}_bs_${BATCH_SIZE}_st_${SOLVER_TYPE}.log diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..89415c23f93c6dd96f5256235a1050aa7a8e5f13 --- /dev/null +++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py @@ -0,0 +1,97 @@ +import argparse +import time + +import pytest +import torch +from model_zoo import GPTLMLoss, get_gpt2_components +from torch.utils._pytree import tree_map + +import colossalai +from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer +from colossalai.auto_parallel.offload.mem_optimize import memory_optimize +from colossalai.auto_parallel.offload.solver import NOT_NVML +from colossalai.fx.profiler import parameter_size +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import spawn +from colossalai.utils import get_current_device + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_type', type=str, default="gpt2_medium") + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--solver_type', type=str, default='asyn') + parser.add_argument('--memory_budget', type=float, default=16) + return parser.parse_args() + + +@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +def train_gpt(args): + memory_budget = args.memory_budget * 1024 * 1024 * 1024 + solver_type = args.solver_type + model_type = args.model_type + batch_size = args.batch_size + + # build model + model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size) + label = torch.randint(low=0, high=128, size=( + 64, + 8, + ), device=get_current_device()) + criterion = GPTLMLoss() + + start_time = time.time() + model = model_builder() + model.train() + param_size = parameter_size(model) / 1024**2 / 2 + init_time = time.time() - start_time + print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s") + + data_args = data_gen(device="cpu") + wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x + data_args = tree_map(wrap_fn, data_args) + start_time = time.time() + model = memory_optimize(model, data_args, memory_budget, solver_type) + solver_time = time.time() - start_time + print(f"solver_time={solver_time:.3f} s") + + hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3) + optim = AMPOptimizer(hybrid_optimizer, model) + + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + time_list = [] + data_args = data_gen(device="cuda") + data_args = tree_map(wrap_fn, data_args) + for step in range(10): + optim.zero_grad() + torch.cuda.synchronize() + start_time = time.time() + loss = criterion(model(**data_args), label) + optim.backward(loss) + torch.cuda.synchronize() + time_list.append(time.time() - start_time) + optim.step() + + torch.cuda.synchronize() + + exec_time = sum(sorted(time_list)[:5]) / 5 + runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 + runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 + print(f'solver_type: {solver_type} | model_type: {model_type}') + print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' + f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') + print(time_list) + + +def run(rank, world_size, port, args): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + train_gpt(args) + + +if __name__ == '__main__': + args = parse_args() + spawn(run, 1, args=args) diff --git a/examples/language/gpt/experiments/auto_parallel/README.md b/examples/language/gpt/experiments/auto_parallel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1c8b1c35109fca737a6269c54d03194e24422524 --- /dev/null +++ b/examples/language/gpt/experiments/auto_parallel/README.md @@ -0,0 +1,44 @@ +# Auto-Parallelism with GPT2 + +## Requirements + +Before you can launch training, you need to install the following requirements. + +### Install PyTorch + +```bash +#conda +conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch +#pip +pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 +``` + +### Install [Colossal-AI v0.2.0](https://colossalai.org/download/) From Official Website + +```bash +pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org +``` + +### Install transformers + +```bash +pip install transformers +``` + +### Install pulp and coin-or-cbc + +```bash +pip install pulp +conda install -c conda-forge coin-or-cbc +``` + +## Dataset + +For simplicity, the input data is randomly generated here. + +## Training + +```bash +#Run the auto parallel resnet example with 4 GPUs with a dummy dataset. +colossalai run --nproc_per_node 4 auto_parallel_with_gpt.py +``` diff --git a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..e331fc8fcf10639366fc809e0558958c419a487a --- /dev/null +++ b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py @@ -0,0 +1,98 @@ +from functools import partial +from time import time + +import psutil +import torch +import transformers +from gpt_modules import GPT2LMHeadModel, GPTLMLoss + +from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize +from colossalai.core import global_context as gpc +from colossalai.initialize import launch_from_torch +from colossalai.logging import disable_existing_loggers, get_dist_logger + +BATCH_SIZE = 16 +SEQ_LENGTH = 1024 +HIDDEN_DIM = 4096 +NUM_HEADS = 16 +NUM_LAYERS = 4 +VOCAB_SIZE = 50257 +NUM_STEPS = 10 +FP16 = True + + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 8 + + +# Randomly Generated Data +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def main(): + disable_existing_loggers() + launch_from_torch(config={}) + logger = get_dist_logger() + config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM) + if FP16: + model = GPT2LMHeadModel(config=config).half().to('cuda') + else: + model = GPT2LMHeadModel(config=config).to('cuda') + global_numel = sum([p.numel() for p in model.parameters()]) + + meta_input_sample = { + 'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'), + 'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'), + } + + gm, solution = autoparallelize(model, meta_input_sample, return_solution=True) + + # print solution on rank 0 + if gpc.get_global_rank() == 0: + for node_strategy in solution: + print(node_strategy) + + # build criterion + criterion = GPTLMLoss() + + optimizer = torch.optim.Adam(gm.parameters(), lr=0.01) + logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + get_tflops_func = partial(get_tflops, global_numel, BATCH_SIZE, SEQ_LENGTH) + torch.cuda.synchronize() + model.train() + + for n in range(10): + # we just use randomly generated data here + input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE) + optimizer.zero_grad() + start = time() + outputs = gm(input_ids, attn_mask) + loss = criterion(outputs, input_ids) + loss.backward() + optimizer.step() + torch.cuda.synchronize() + step_time = time() - start + logger.info( + f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', + ranks=[0]) + torch.cuda.synchronize() + + +if __name__ == '__main__': + main() diff --git a/examples/language/gpt/experiments/auto_parallel/gpt_modules.py b/examples/language/gpt/experiments/auto_parallel/gpt_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..95feaec38c26794760bf01c34d45fe06f61b401a --- /dev/null +++ b/examples/language/gpt/experiments/auto_parallel/gpt_modules.py @@ -0,0 +1,253 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers.activations import ACT2FN +from transformers.models.gpt2.modeling_gpt2 import BaseModelOutputWithPastAndCrossAttentions, GPT2PreTrainedModel +from transformers.pytorch_utils import Conv1D + + +class GPT2MLP(nn.Module): + + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + return hidden_states + + +# The reason Why we don't import GPT2Attention from transformers directly is that: +# 1. The tracer will not work correctly when we feed meta_args and concrete_args at same time, +# so we have to build the customized GPT2Attention class and remove the conditional branch manually. +# 2. The order of split and view op has been changed in the customized GPT2Attention class, the new +# order is same as megatron-lm gpt model. +class GPT2Attention(nn.Module): + + def __init__(self, config, layer_idx=None): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), + dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + self.scale_attn_weights = config.scale_attn_weights + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.pruned_heads = set() + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / (value.size(-1)**0.5) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool) + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = attn_weights.type(value.dtype) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + + qkv = self.c_attn(hidden_states) + query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3) + present = (key, value) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + return attn_output + + +class GPT2Block(nn.Module): + + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + ) + # residual connection + hidden_states = attn_outputs + residual + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + return hidden_states + + +class GPT2Model(GPT2PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + + device = input_ids.device + + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + + hidden_states = inputs_embeds + position_embeds + + output_shape = input_shape + (hidden_states.size(-1),) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i]) + hidden_states = outputs + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + + return hidden_states + + +class GPT2LMHeadModel(GPT2PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ): + transformer_outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + ) + lm_logits = self.lm_head(transformer_outputs) + + return lm_logits + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) diff --git a/examples/language/gpt/experiments/auto_parallel/requirements.txt b/examples/language/gpt/experiments/auto_parallel/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1b2561f098d55c6f2f336591eff762932086f9de --- /dev/null +++ b/examples/language/gpt/experiments/auto_parallel/requirements.txt @@ -0,0 +1,4 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 +transformers >= 4.23.1 +PuLP >= 2.7.0 diff --git a/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_12_layers.pt b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_12_layers.pt new file mode 100644 index 0000000000000000000000000000000000000000..7b8cd7edd11e6d1f605e0e9f992b6a13676ecd10 Binary files /dev/null and b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_12_layers.pt differ diff --git a/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_1_layers.pt b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_1_layers.pt new file mode 100644 index 0000000000000000000000000000000000000000..9b431a45baba43b9581fb5cf3d4bf39a2aaea5d6 Binary files /dev/null and b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_1_layers.pt differ diff --git a/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_4_layers.pt b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_4_layers.pt new file mode 100644 index 0000000000000000000000000000000000000000..79a448c1b06f1db8731d2d45f988ff0b57810b04 Binary files /dev/null and b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_4_layers.pt differ diff --git a/examples/language/gpt/experiments/pipeline_parallel/README.md b/examples/language/gpt/experiments/pipeline_parallel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5af994a006651e42bf708d2ed067f1df8c64c5eb --- /dev/null +++ b/examples/language/gpt/experiments/pipeline_parallel/README.md @@ -0,0 +1,38 @@ +# Pipeline Parallelism Demo with GPT2 + +## Requirements + +Before you can launch training, you need to install the following requirements. + +### Install PyTorch + +```bash +#conda +conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch +#pip +pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 +``` + +### Install [Colossal-AI v0.2.0](https://colossalai.org/download/) From Official Website + +```bash +pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org +``` + +### Install transformers + +```bash +pip install transformers +``` + +## Dataset + +For simplicity, the input data is randomly generated here. + +## Training + +```bash +#Run the Pipeline Parallel on GPT with default setting and a dummy dataset. +#You can change the GPU number or microbatch number in the run.sh . +bash run.sh +``` diff --git a/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py b/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..c31b3fa6d1035a0238c1d56fc678351c0fa2068d --- /dev/null +++ b/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py @@ -0,0 +1,73 @@ +from torch import nn +from transformers import GPT2Config, GPT2LMHeadModel + + +## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.config = GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size) + self.model = GPT2LMHeadModel(self.config) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +def gpt2_medium(checkpoint=False): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_xl(checkpoint=True): + return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint) + + +def gpt2_10b(checkpoint=True): + return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_14b(checkpoint=True): + return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_20b(checkpoint=True): + return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_24b(checkpoint=True): + return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) + + +def model_builder(model_size: str) -> callable: + if model_size == "gpt2_medium": + return gpt2_medium + elif model_size == "gpt2_xl": + return gpt2_xl + elif model_size == "gpt2_10b": + return gpt2_10b + elif model_size == "gpt2_14b": + return gpt2_14b + elif model_size == "gpt2_20b": + return gpt2_20b + elif model_size == "gpt2_24b": + return gpt2_24b + else: + raise TypeError(f"model_builder {model_size}") + + +__all__ = ['model_builder'] diff --git a/examples/language/gpt/experiments/pipeline_parallel/requirements.txt b/examples/language/gpt/experiments/pipeline_parallel/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..137a69e80498223cd7581a62e2e27320b77682a0 --- /dev/null +++ b/examples/language/gpt/experiments/pipeline_parallel/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/examples/language/gpt/experiments/pipeline_parallel/run.sh b/examples/language/gpt/experiments/pipeline_parallel/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..235cefcbc6df6b24d261bb88431236c9a3e492b3 --- /dev/null +++ b/examples/language/gpt/experiments/pipeline_parallel/run.sh @@ -0,0 +1,7 @@ +export GPUNUM=${GPUNUM:-4} +export BATCH_SIZE=${BATCH_SIZE:-16} +export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} +export NUM_MICROBATCH=${NUM_MICROBATCH:-8} + +mkdir -p pp_logs +python train_gpt_pp.py --device="cuda" --model_type=${MODEL_TYPE} --num_microbatches=${NUM_MICROBATCH} --world_size=${GPUNUM} --batch_size=${BATCH_SIZE} 2>&1 | tee ./pp_logs/${MODEL_TYPE}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_nm_${NUM_MICROBATCH}.log diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py new file mode 100644 index 0000000000000000000000000000000000000000..ad69888b8cc80ba7cd1ea52bd853ba61bc6c297b --- /dev/null +++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py @@ -0,0 +1,180 @@ +import argparse +import time +from functools import partial + +import torch +from model_zoo import model_builder +from torch import nn +from tqdm import tqdm + +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import ( + avgnode_split_pass, + gpipe_dp_split_pass, + split_with_split_nodes_pass, +) +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.pipeline.middleware.adaptor import get_fx_topology +from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.pipeline.rpc.utils import rpc_run + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_type', type=str, default="gpt2_medium") + parser.add_argument('--world_size', type=int, default=2) + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--dp_degree', type=int, default=1) + parser.add_argument('--tp_degree', type=int, default=1) + parser.add_argument('--num_microbatches', type=int, default=2) + parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') + parser.add_argument('--master_addr', type=str, default='localhost') + parser.add_argument('--master_port', type=str, default='29011') + parser.add_argument('--num_worker_threads', type=int, default=128) + return parser.parse_args() + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +# Randomly Generated Data +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + + +# Create annotated model which is noted where to be splitted. +def get_annotated_model(model, data_kwargs, num_stages, num_microbatches): + tracer = ColoTracer() + meta_args = {k: v.to('meta') for k, v in data_kwargs.items()} + graph = tracer.trace(root=model, meta_args=meta_args) + gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) + + interp_meta_args = tuple([v.to('meta') for k, v in data_kwargs.items()]) + interp = MetaInfoProp(gm) + interp.run(*interp_meta_args) + + #annotated_model = avgnode_split_pass(gm, num_stages) + annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode='block', block_limit=0.01) + + return annotated_model + + +def create_partition_module(pp_rank: int, num_stages: int, model, data_kwargs, num_microbatches): + annotated_model = get_annotated_model(model, data_kwargs, num_stages, num_microbatches) + top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True) + topo = get_fx_topology(top_module) + for submodule in split_submodules: + if isinstance(submodule, torch.fx.GraphModule): + setattr(submodule, '_topo', topo) + return split_submodules[pp_rank + 1] + + +def partition(model, data_kwargs, num_microbatches, pp_rank: int, chunk: int, stage_num: int): + module = create_partition_module(pp_rank, stage_num, model, data_kwargs, num_microbatches) + return module + + +def run_master(args): + batch_size = args.batch_size + device = args.device + world_size = args.world_size + stage_num = world_size + num_microbatches = args.num_microbatches + model_type = args.model_type + # batch size per DP degree + SEQ_LEN = 1024 + VOCAB_SIZE = 50257 + NUM_STEPS = 10 + WARMUP_STEPS = 1 + + disable_existing_loggers() + logger = get_dist_logger() + logger.info(f"{args.model_type}, batch size {batch_size}, num stage {stage_num}, num microbatch {num_microbatches}", + ranks=[0]) + + torch.manual_seed(123) + + # build criterion + criterion = GPTLMLoss() + + # warm up pipeline fx partition + input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE) + warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask} + + # create model + logger.info(f'start model_builder') + model = model_builder(model_type)(checkpoint=False) + logger.info(f'end model_builder') + + # set 1f1b pipeline engine + pp_engine = FillDrainPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches), + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=1, + criterion=criterion, + metric=None, + checkpoint=False) + + partition_numels = pp_engine.remote_numels() + for rank, numel in partition_numels.items(): + logger.info(f'{rank=} numel in the partition:{numel}') + + # build optim + pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3) + + ranks_tflops = {} + for n in range(NUM_STEPS): + # we just use randomly generated data here + input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE) + batch = {'input_ids': input_ids, 'attention_mask': attn_mask} + + start = time.time() + outputs = pp_engine.forward_backward(batch=batch, labels=input_ids, forward_only=False) + step_time = time.time() - start + + for rank, numel in partition_numels.items(): + if rank not in ranks_tflops: + ranks_tflops[rank] = [] + step_tflops = get_tflops(numel, batch_size, SEQ_LEN, step_time) + + logger.info( + f"Rank{rank} , [{n + 1}/{NUM_STEPS}] , Step time: {step_time:.3f}s, TFLOPS: {get_tflops(numel, batch_size, SEQ_LEN, step_time):.3f}", + ranks=[0], + ) + + if n >= WARMUP_STEPS: + ranks_tflops[rank].append(step_tflops) + + median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS + gpu_tflops = [] + for rank, tflops_list in ranks_tflops.items(): + tflops_list.sort() + gpu_tflops.append(tflops_list[median_index]) + logger.info(f"GPU{rank} Median TFLOPS is {tflops_list[median_index]:.3f}") + + logger.info(f"Total TFLOPS is {sum(gpu_tflops):.3f}") + logger.info(f"Avg TFLOPS per GPU is {sum(gpu_tflops) / world_size:.3f}") + + +if __name__ == '__main__': + args = parse_args() + rpc_run(args, run_master) diff --git a/examples/language/gpt/gemini/benchmark_gemini.sh b/examples/language/gpt/gemini/benchmark_gemini.sh new file mode 100644 index 0000000000000000000000000000000000000000..3a42e13645f6937c6a127b454868a5141a24eb0f --- /dev/null +++ b/examples/language/gpt/gemini/benchmark_gemini.sh @@ -0,0 +1,22 @@ +for MODEL_TYPE in "gpt2_medium"; do + for DISTPLAN in "CAI_Gemini"; do + for BATCH_SIZE in 16; do + for GPUNUM in 1 2 4 8; do + for TPDEGREE in 1 2 4 8; do + if [ ${TPDEGREE} -gt ${GPUNUM} ]; then + continue + fi + for PLACEMENT in "cpu" "auto"; do + echo "****************** Begin ***************************" + echo "+ benchmrking MODEL ${MODEL_TYPE} DISTPLAN ${DISTPLAN} GPU ${GPUNUM} BS ${BATCH_SIZE} TP ${TPDEGREE} POLICY ${PLACEMENT}" + MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \ + bash ./run_gemini.sh + echo "****************** Finished ***************************" + echo "" + echo "" + done + done + done + done + done +done diff --git a/examples/language/gpt/gemini/commons/model_zoo.py b/examples/language/gpt/gemini/commons/model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..65124d9e488403e3473c23b4cd7a31f5e2173871 --- /dev/null +++ b/examples/language/gpt/gemini/commons/model_zoo.py @@ -0,0 +1,85 @@ +from torch import nn +from transformers import GPT2Config, GPT2LMHeadModel + + +## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.config = GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size) + self.model = GPT2LMHeadModel(self.config) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +def gpt2_medium(checkpoint=False): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_xl(checkpoint=True): + return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint) + + +def gpt2_10b(checkpoint=True): + return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_14b(checkpoint=True): + return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_20b(checkpoint=True): + return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_24b(checkpoint=True): + return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_30b(checkpoint=True): + return GPTLMModel(hidden_size=8192, num_layers=37, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_40b(checkpoint=True): + return GPTLMModel(hidden_size=8192, num_layers=50, num_attention_heads=16, checkpoint=checkpoint) + + +def model_builder(model_size: str) -> callable: + if model_size == "gpt2_medium": + return gpt2_medium + elif model_size == "gpt2_xl": + return gpt2_xl + elif model_size == "gpt2_10b": + return gpt2_10b + elif model_size == "gpt2_14b": + return gpt2_14b + elif model_size == "gpt2_20b": + return gpt2_20b + elif model_size == "gpt2_24b": + return gpt2_24b + elif model_size == "gpt2_30b": + return gpt2_30b + elif model_size == "gpt2_40b": + return gpt2_40b + else: + raise TypeError(f"model_builder {model_size}") + + +__all__ = ['model_builder'] diff --git a/examples/language/gpt/gemini/commons/utils.py b/examples/language/gpt/gemini/commons/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd098c1927c71f0bc77c26ec7e5ac37688b7c4f --- /dev/null +++ b/examples/language/gpt/gemini/commons/utils.py @@ -0,0 +1,41 @@ +import time +from contextlib import nullcontext + +import torch +from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler + + +class DummyProfiler: + + def __init__(self): + self.step_number = 0 + + def step(self): + self.step_number += 1 + + +# Randomly Generated Data +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + + +def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): + if enable_flag: + return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), + on_trace_ready=tensorboard_trace_handler(save_dir), + record_shapes=True, + profile_memory=True) + else: + return nullcontext(DummyProfiler()) + + +def get_time_stamp(): + cur_time = time.strftime("%d-%H:%M", time.localtime()) + return cur_time diff --git a/examples/language/gpt/gemini/requirements.txt b/examples/language/gpt/gemini/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..137a69e80498223cd7581a62e2e27320b77682a0 --- /dev/null +++ b/examples/language/gpt/gemini/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/examples/language/gpt/gemini/run_gemini.sh b/examples/language/gpt/gemini/run_gemini.sh new file mode 100644 index 0000000000000000000000000000000000000000..ad4e9419c1bdf01bb5e75350b9cfb715908c68ec --- /dev/null +++ b/examples/language/gpt/gemini/run_gemini.sh @@ -0,0 +1,31 @@ +set -x +# distplan in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"] +export DISTPLAN=${DISTPLAN:-"CAI_Gemini"} + +# The following options only valid when DISTPLAN="colossalai" +export GPUNUM=${GPUNUM:-1} +export TPDEGREE=${TPDEGREE:-1} +export PLACEMENT=${PLACEMENT:-"cpu"} +export USE_SHARD_INIT=${USE_SHARD_INIT:-False} +export BATCH_SIZE=${BATCH_SIZE:-16} +export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} +export TRAIN_STEP=${TRAIN_STEP:-10} +# export PYTHONPATH=$PWD:$PYTHONPATH + +if [ ${USE_SHARD_INIT} = "True" ]; then + USE_SHARD_INIT="--shardinit" +else + USE_SHARD_INIT="" +fi + +mkdir -p gemini_logs + +torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \ +--tp_degree=${TPDEGREE} \ +--model_type=${MODEL_TYPE} \ +--batch_size=${BATCH_SIZE} \ +--placement=${PLACEMENT} \ +${USE_SHARD_INIT} \ +--distplan=${DISTPLAN} \ +--train_step=${TRAIN_STEP} \ +2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log diff --git a/examples/language/gpt/gemini/test_ci.sh b/examples/language/gpt/gemini/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..6079d5ed615bd2ea4d6d18cc168a8371219103c1 --- /dev/null +++ b/examples/language/gpt/gemini/test_ci.sh @@ -0,0 +1,35 @@ +set -x +$(cd `dirname $0`;pwd) +export TRAIN_STEP=4 + +for MODEL_TYPE in "gpt2_medium"; do + for DISTPLAN in "colossalai"; do + for BATCH_SIZE in 2; do + for GPUNUM in 1 4; do + for TPDEGREE in 1 2; do + if [ ${TPDEGREE} -gt ${GPUNUM} ]; then + continue + fi + for PLACEMENT in "cpu" "auto"; do + MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \ + bash ./run_gemini.sh + done + done + done + done + done + + for DISTPLAN in "zero1" "zero2"; do + for BATCH_SIZE in 2; do + for GPUNUM in 1 4; do + for TPDEGREE in 1; do + if [ ${TPDEGREE} -gt ${GPUNUM} ]; then + continue + fi + MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE}\ + bash ./run_gemini.sh + done + done + done + done +done diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a7fa36d02140e6919c2f2c3fe8c1f5596d2943 --- /dev/null +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -0,0 +1,352 @@ +import os +from functools import partial +from time import time + +import psutil +import torch +import torch.nn as nn +from commons.model_zoo import model_builder +from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp +from packaging import version +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper + +CAI_VERSION = colossalai.__version__ + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument( + "--distplan", + type=str, + default='CAI_Gemini', + help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", + ) + parser.add_argument( + "--tp_degree", + type=int, + default=1, + help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--placement", + type=str, + default='cpu', + help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--shardinit", + action='store_true', + help= + "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=8, + help="batch size per DP group of training.", + ) + parser.add_argument( + "--model_type", + type=str, + default="gpt2_medium", + help="model model scale", + ) + parser.add_argument( + "--train_step", + type=int, + default=10, + help="training iterations for test", + ) + + args = parser.parse_args() + return args + + +# Parameter Sharding Strategies for Tensor Parallelism +def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(-1, param, pg) + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def get_model_size(model: nn.Module): + total_numel = 0 + for module in model.modules(): + for p in module.parameters(recurse=False): + total_numel += p.numel() + return total_numel + + +def model_size_formatter(numel: int) -> str: + GB_SIZE = 10**9 + MB_SIZE = 10**6 + KB_SIZE = 10**3 + if numel >= GB_SIZE: + return f'{numel / GB_SIZE:.1f}B' + elif numel >= MB_SIZE: + return f'{numel / MB_SIZE:.1f}M' + elif numel >= KB_SIZE: + return f'{numel / KB_SIZE:.1f}K' + else: + return str(numel) + + +def set_cpu_maximum_parallelism(): + conf_str = torch.__config__.parallel_info() + inter_str = conf_str.split("hardware_concurrency() : ")[1] + max_concurrency = inter_str.split('\n')[0] + os.environ["OMP_NUM_THREADS"] = max_concurrency + print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.") + + +# Tensor Parallel +def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): + """tensor_parallelize + Sharding the Model Parameters. + + Args: + model (torch.nn.Module): a torch module to be sharded + """ + for mn, module in model.named_modules(): + for pn, param in module.named_parameters(recurse=False): + # NOTE() a param maybe shared by two modules + if hasattr(param, 'visited'): + continue + + # if shard init, then convert param to replica and use the dp-only ProcessGroup + param: ColoParameter = param + param.set_dist_spec(ReplicaSpec()) + param.set_process_group(pg) + + # shard it w.r.t tp pattern + if 'mlp.c_fc' in mn: + if 'weight' in pn or 'bias' in pn: + split_param_col_tp1d(param, pg) # colmn slice + # keep the shape of the output from c_fc + param.compute_spec.set_output_replicate(False) + else: + param.set_dist_spec(ReplicaSpec()) + elif 'mlp.c_proj' in mn: + if 'weight' in pn: + split_param_row_tp1d(param, pg) # row slice + else: + param.set_dist_spec(ReplicaSpec()) + elif 'wte' in mn or 'wpe' in mn: + split_param_col_tp1d(param, pg) # colmn slice + elif 'c_attn' in mn or 'c_proj' in mn: + split_param_col_tp1d(param, pg) # colmn slice + else: + param.set_dist_spec(ReplicaSpec()) + param.visited = True + + +def main(): + # version check + # this example is supposed to work for versions greater than 0.2.0 + assert version.parse(CAI_VERSION) >= version.parse("0.2.0") + + set_cpu_maximum_parallelism() + args = parse_args() + + # if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]: + if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]: + raise TypeError(f"{args.distplan} is error") + + # batch size per DP degree + BATCH_SIZE = args.batch_size + SEQ_LEN = 1024 + VOCAB_SIZE = 50257 + + NUM_STEPS = args.train_step + + WARMUP_STEPS = 1 + assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps" + assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median" + PROF_FLAG = False # The flag of profiling, False by default + + disable_existing_loggers() + colossalai.launch_from_torch(config={}) + + logger = get_dist_logger() + logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0]) + + # build criterion + criterion = GPTLMLoss() + + torch.manual_seed(123) + if args.distplan.startswith("CAI"): + # all param must use the same process group. + world_size = torch.distributed.get_world_size() + shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None + default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None + + if args.shardinit and args.distplan != "CAI_Gemini": + raise RuntimeError("You can only use shardinit with CAI_Gemini") + + # build GPT model + with ColoInitContext(device=get_current_device(), + dtype=torch.half, + default_dist_spec=default_dist_spec, + default_pg=shard_pg): + model = model_builder(args.model_type)(checkpoint=True) + + tp_pg = ProcessGroup(tp_degree=args.tp_degree) + # Tensor Parallelism (TP) + # You should notice that v0.1.10 is not compatible with TP degree > 1 + if args.tp_degree > 1: + tensor_parallelize(model, tp_pg) + + # asign running configurations + gemini_config = None + if args.distplan.startswith("CAI_ZeRO"): + optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) + elif args.distplan == "CAI_Gemini": + gemini_config = dict(strict_ddp_mode=args.tp_degree == 1, + device=get_current_device(), + placement_policy=args.placement, + pin_memory=True, + hidden_dim=model.config.n_embd, + search_range_mb=128) + optim_config = dict(gpu_margin_mem_ratio=0.) + else: + raise RuntimeError + + # build a highly optimized gpu/cpu optimizer + optimizer = HybridAdam(model.parameters(), lr=1e-3) + + if args.distplan == "CAI_ZeRO1": + zero_stage = 1 + elif args.distplan == "CAI_ZeRO2": + zero_stage = 2 + elif args.distplan == "CAI_Gemini": + zero_stage = 3 + else: + raise RuntimeError + + # wrap your model and optimizer + model = zero_model_wrapper(model, zero_stage, gemini_config) + optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config) + + logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) + elif args.distplan.startswith("Pytorch"): + assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples." + model = model_builder(args.model_type)(checkpoint=True).cuda() + model = DDP(model) + if args.distplan.endswith("DDP"): + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + elif args.distplan.endswith("ZeRO"): + from torch.distributed.optim import ZeroRedundancyOptimizer + optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3) + else: + raise RuntimeError + + # model is shared after TP + numel = get_model_size(model) + logger.info(f"the size of testing model size is {model_size_formatter(numel)}.") + logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + + # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu + # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree) + # = batch_per_DP_group * numel * seq_len * 8 + get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN) + + torch.cuda.synchronize() + model.train() + tflops_list = [] + + def train_step(): + # we just use randomly generated data here + input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE) + optimizer.zero_grad() + + start = time() + outputs = model(input_ids, attn_mask) + loss = criterion(outputs, input_ids) + torch.cuda.synchronize() + fwd_end = time() + fwd_time = fwd_end - start + logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0]) + + if args.distplan.startswith("CAI"): + optimizer.backward(loss) + elif args.distplan.startswith("Pytorch"): + loss.backward() + else: + raise RuntimeError + + torch.cuda.synchronize() + bwd_end = time() + bwd_time = bwd_end - fwd_end + logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0]) + + optimizer.step() + torch.cuda.synchronize() + optim_time = time() - bwd_end + step_time = time() - start + logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0]) + + step_tflops = get_tflops_func(step_time) + logger.info( + f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s", + ranks=[0], + ) + if n >= WARMUP_STEPS: + tflops_list.append(step_tflops) + + demo_profiler = get_profile_context(PROF_FLAG, + WARMUP_STEPS, + NUM_STEPS - WARMUP_STEPS, + save_dir=f"profile/{get_time_stamp()}-demo") + + with demo_profiler as prof: + for n in range(NUM_STEPS): + train_step() + prof.step() + + tflops_list.sort() + median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS + logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}") + torch.cuda.synchronize() + + +if __name__ == '__main__': + main() diff --git a/examples/language/gpt/requirements.txt b/examples/language/gpt/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ef58bb76bfc86d9eda33a76ee7687b6204d65f9c --- /dev/null +++ b/examples/language/gpt/requirements.txt @@ -0,0 +1,2 @@ +transformers >= 4.23 +colossalai diff --git a/examples/language/gpt/test_ci.sh b/examples/language/gpt/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..d67c17229e711ba0cafb0837c260f813ce595537 --- /dev/null +++ b/examples/language/gpt/test_ci.sh @@ -0,0 +1,2 @@ +set -x +cd gemini && bash test_ci.sh diff --git a/examples/language/gpt/titans/LICENSE b/examples/language/gpt/titans/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/examples/language/gpt/titans/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/examples/language/gpt/titans/README.md b/examples/language/gpt/titans/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e954f35fae0d2b60c3a7aedceabb7849c9f1c22d --- /dev/null +++ b/examples/language/gpt/titans/README.md @@ -0,0 +1,48 @@ +# Run GPT With Colossal-AI + +## How to Prepare Webtext Dataset + +You can download the preprocessed sample dataset for this demo via our [Google Drive sharing link](https://drive.google.com/file/d/1QKI6k-e2gJ7XgS8yIpgPPiMmwiBP_BPE/view?usp=sharing). + + +You can also avoid dataset preparation by using `--use_dummy_dataset` during running. + +## Run this Demo + +Use the following commands to install prerequisites. + +```bash +# assuming using cuda 11.3 +pip install -r requirements.txt +``` + +Use the following commands to execute training. + +```Bash +#!/usr/bin/env sh +# if you want to use real dataset, then remove --use_dummy_dataset +# export DATA=/path/to/small-gpt-dataset.json' + +# run on a single node +colossalai run --nproc_per_node= train_gpt.py --config configs/ --from_torch --use_dummy_dataset + +# run on multiple nodes +colossalai run --nproc_per_node= \ + --master_addr \ + --master_port \ + --hosts \ + train_gpt.py \ + --config configs/ \ + --from_torch \ + --use_dummy_dataset + +# run on multiple nodes with slurm +srun python \ + train_gpt.py \ + --config configs/ \ + --host \ + --use_dummy_dataset + +``` + +You can set the `` to any file in the `configs` folder. To simply get it running, you can start with `gpt_small_zero3_pp1d.py` on a single node first. You can view the explanations in the config file regarding how to change the parallel setting. diff --git a/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py new file mode 100644 index 0000000000000000000000000000000000000000..7bf53303948a641a5e74504e9979a336699f5aaf --- /dev/null +++ b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py @@ -0,0 +1,31 @@ +from model import GPT2_small_pipeline_hybrid + +from colossalai.nn.optimizer import HybridAdam +from colossalai.zero.shard_utils import TensorShardStrategy + +BATCH_SIZE = 8 +NUM_EPOCHS = 10 +SEQ_LEN = 1024 +NUM_MICRO_BATCHES = 4 +HIDDEN_SIZE = 768 +TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE) + +# if you do no want zero, just comment out this dictionary +zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()), + optimizer_config=dict(initial_scale=2**5)) + +optimizer = dict( + type=HybridAdam, + lr=0.000015, + weight_decay=1e-2, +) + +model = dict(type=GPT2_small_pipeline_hybrid, checkpoint=True, num_chunks=1) + +# pipeline parallel: modify integer value for the number of pipeline stages +# tensor parallel: modify size to set the tensor parallel size, usually the number of GPUs per node +# for the current model implementation, mode can only be 1D or None +parallel = dict( + pipeline=1, + tensor=dict(size=2, mode='1d'), +) diff --git a/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py new file mode 100644 index 0000000000000000000000000000000000000000..9f9816b3004f8108dec385da6eff94499ce730e4 --- /dev/null +++ b/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py @@ -0,0 +1,31 @@ +from model import GPT3_pipeline_hybrid + +from colossalai.nn.optimizer import HybridAdam +from colossalai.zero.shard_utils import TensorShardStrategy + +BATCH_SIZE = 192 +NUM_EPOCHS = 60 +SEQ_LEN = 2048 +NUM_MICRO_BATCHES = 192 +HIDDEN_SIZE = 12288 +TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE) + +# if you do no want zero, just comment out this dictionary +zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()), + optimizer_config=dict(initial_scale=2**16)) + +optimizer = dict( + type=HybridAdam, + lr=0.00015, + weight_decay=1e-2, +) + +model = dict(type=GPT3_pipeline_hybrid, checkpoint=True, num_chunks=1) + +# pipeline parallel: modify integer value for the number of pipeline stages +# tensor parallel: modify size to set the tensor parallel size, usually the number of GPUs per node +# for the current model implementation, mode can only be 1D or None +parallel = dict( + pipeline=1, + tensor=dict(size=2, mode='1d'), # for the current model implementation, mode can only be 1D or None +) diff --git a/examples/language/gpt/titans/dataset/webtext.py b/examples/language/gpt/titans/dataset/webtext.py new file mode 100644 index 0000000000000000000000000000000000000000..64f5944a97f927a818b092613179fb064466a730 --- /dev/null +++ b/examples/language/gpt/titans/dataset/webtext.py @@ -0,0 +1,43 @@ +import json +import os +from typing import Optional + +import torch +from torch.utils.data import Dataset +from transformers import GPT2Tokenizer + +from colossalai.registry import DATASETS + + +@DATASETS.register_module +class WebtextDataset(Dataset): + + def __init__(self, path: Optional[str] = None, seq_len=1024) -> None: + super().__init__() + if path is not None: + root = os.path.dirname(path) + encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt') + if os.path.isfile(encoded_data_cache_path): + seq_len_, data, attention_mask = torch.load(encoded_data_cache_path) + if seq_len_ == seq_len: + self.data = data + self.attention_mask = attention_mask + return + raw_data = [] + with open(path) as f: + for line in f.readlines(): + raw_data.append(json.loads(line)['text']) + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.unk_token + encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt') + self.data = encoded_data['input_ids'] + self.attention_mask = encoded_data['attention_mask'] + else: + self.data = torch.randint(0, 50257, (10240, seq_len)) + self.attention_mask = torch.ones_like(self.data) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index] diff --git a/examples/language/gpt/titans/model/__init__.py b/examples/language/gpt/titans/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eec48ef893fbc401db067701b645949f17008c40 --- /dev/null +++ b/examples/language/gpt/titans/model/__init__.py @@ -0,0 +1,3 @@ +from .embed import vocab_parallel_cross_entropy +from .gpt1d import * +from .pipeline_gpt1d import * diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py new file mode 100644 index 0000000000000000000000000000000000000000..6369b9f8c5a136b534a316718cf195888e6d57cc --- /dev/null +++ b/examples/language/gpt/titans/model/embed.py @@ -0,0 +1,599 @@ +import torch +import torch.nn.init as init +from torch import Tensor +from torch import distributed as dist +from torch import nn as nn +from torch.nn import functional as F +from torch.nn.parameter import Parameter + +from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.nn.layer.base_layer import ParallelLayer +from colossalai.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input +from colossalai.nn.layer.parallel_1d.layers import Linear1D_Row +from colossalai.nn.layer.utils import divide +from colossalai.registry import LAYERS, LOSSES, MODELS +from colossalai.utils import get_current_device + + +class VocabParallelEmbedding(torch.nn.Module): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, + hidden_size, + vocab_size, + max_sequence_length, + embedding_dropout_prob, + num_tokentypes=0, + dtype=torch.float): + super(VocabParallelEmbedding, self).__init__() + + self.hidden_size = hidden_size + self.num_tokentypes = num_tokentypes + + # Word embeddings (parallel). + self.word_embeddings = VocabParallelEmbedding1D(vocab_size, self.hidden_size, dtype=dtype) + self._word_embeddings_key = 'word_embeddings' + + # Position embedding (serial). + self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size, dtype=dtype) + self._position_embeddings_key = 'position_embeddings' + # Initialize the position embeddings. + # self.init_method(self.position_embeddings.weight) + + # Token type embedding. + # Add this as an optional field that can be added through + # method call so we can load a pretrain model without + # token types and add them as needed. + self._tokentype_embeddings_key = 'tokentype_embeddings' + if self.num_tokentypes > 0: + self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size, dtype=dtype) + # Initialize the token-type embeddings. + # self.init_method(self.tokentype_embeddings.weight) + else: + self.tokentype_embeddings = None + + # Embeddings dropout + self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) + + def zero_parameters(self): + """Zero out all parameters in embedding.""" + self.word_embeddings.weight.data.fill_(0) + self.word_embeddings.weight.shared = True + self.position_embeddings.weight.data.fill_(0) + self.position_embeddings.weight.shared = True + if self.num_tokentypes > 0: + self.tokentype_embeddings.weight.data.fill_(0) + self.tokentype_embeddings.weight.shared = True + + def add_tokentype_embeddings(self, num_tokentypes): + """Add token-type embedding. This function is provided so we can add + token-type embeddings in case the pretrained model does not have it. + This allows us to load the model normally and then add this embedding. + """ + if self.tokentype_embeddings is not None: + raise Exception('tokentype embeddings is already initialized') + if torch.distributed.get_rank() == 0: + print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True) + self.num_tokentypes = num_tokentypes + self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) + # Initialize the token-type embeddings. + # self.init_method(self.tokentype_embeddings.weight) + + def forward(self, input_ids, position_ids=None, tokentype_ids=None): + # Embeddings. + if input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + words_embeddings = self.word_embeddings(input_ids) + + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + if position_ids is None: + position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device()) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + position_embeddings = self.position_embeddings(position_ids) + + embeddings = words_embeddings + position_embeddings + + # Dropout. + with seed(ParallelMode.TENSOR): + embeddings = self.embedding_dropout(embeddings) + return embeddings + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + state_dict_[self._word_embeddings_key] \ + = self.word_embeddings.state_dict(destination, prefix, keep_vars) + state_dict_[self._position_embeddings_key] \ + = self.position_embeddings.state_dict( + destination, prefix, keep_vars) + if self.num_tokentypes > 0: + state_dict_[self._tokentype_embeddings_key] \ + = self.tokentype_embeddings.state_dict( + destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Word embedding. + if self._word_embeddings_key in state_dict: + state_dict_ = state_dict[self._word_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'word_embeddings' in key: + state_dict_[key.split('word_embeddings.')[1]] \ + = state_dict[key] + self.word_embeddings.load_state_dict(state_dict_, strict=strict) + + # Position embedding. + if self._position_embeddings_key in state_dict: + state_dict_ = state_dict[self._position_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'position_embeddings' in key: + state_dict_[key.split('position_embeddings.')[1]] \ + = state_dict[key] + self.position_embeddings.load_state_dict(state_dict_, strict=strict) + + # Tokentype embedding. + if self.num_tokentypes > 0: + state_dict_ = {} + if self._tokentype_embeddings_key in state_dict: + state_dict_ = state_dict[self._tokentype_embeddings_key] + else: + # for backward compatibility. + for key in state_dict.keys(): + if 'tokentype_embeddings' in key: + state_dict_[key.split('tokentype_embeddings.')[1]] \ + = state_dict[key] + if len(state_dict_.keys()) > 0: + self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict) + else: + print('***WARNING*** expected tokentype embeddings in the ' + 'checkpoint but could not find it', + flush=True) + + +class VocabParallelEmbedding1D(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + init_method: method to initialize weights. + """ + + def __init__(self, num_embeddings, embedding_dim, dtype=None, init_method=None): + super(VocabParallelEmbedding1D, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + # Set the details for compatibility. + self.padding_idx = None + self.max_norm = None + self.norm_type = 2. + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + self.tensor_model_parallel_size = gpc.tensor_parallel_size + # Divide the weight matrix along the vocabulary dimension. + self.vocab_start_index, self.vocab_end_index = \ + VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D), + self.tensor_model_parallel_size) + self.num_embeddings_per_partition = self.vocab_end_index - \ + self.vocab_start_index + + # Allocate weights and initialize. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs)) + init.uniform_(self.weight, -1, 1) + + def forward(self, input_): + if self.tensor_model_parallel_size > 1: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | \ + (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + else: + masked_input = input_ + # Get the embeddings. + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type, + self.scale_grad_by_freq, self.sparse) + # Mask the output embedding. + if self.tensor_model_parallel_size > 1: + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs. + output = output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + return output + + +@LOSSES.register_module +class vocab_parallel_cross_entropy(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, vocab_parallel_logits, target): + """Helper function for the cross entropy.""" + vocab_parallel_logits = vocab_parallel_logits[..., :-1, :].contiguous() + target = target[..., 1:].contiguous() + return _VocabParallelCrossEntropy.apply(vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)), + target.view(-1)) + + +class _VocabParallelCrossEntropy(torch.autograd.Function): + + @staticmethod + def forward(ctx, vocab_parallel_logits, target): + + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, + op=torch.distributed.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.PARALLEL_1D)) + # Subtract the maximum value. + vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + + # Get the partition's vocab indices + get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size + partition_vocab_size = vocab_parallel_logits.size()[-1] + rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + world_size = gpc.tensor_parallel_size + vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size) + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + # All reduce is needed to get the chunks from other GPUs. + torch.distributed.all_reduce(predicted_logits, + op=torch.distributed.ReduceOp.SUM, + group=gpc.get_group(ParallelMode.PARALLEL_1D)) + + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = vocab_parallel_logits + torch.exp(vocab_parallel_logits, out=exp_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + torch.distributed.all_reduce(sum_exp_logits, + op=torch.distributed.ReduceOp.SUM, + group=gpc.get_group(ParallelMode.PARALLEL_1D)) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits) - predicted_logits + loss = loss.mean() + # Store softmax, target-mask and masked-target for backward pass. + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + return loss + + @staticmethod + def backward(ctx, grad_output): + + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + + # All the inputs have softmax as their gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None + + +class VocabUtility: + """Split the vocabulary into `world_size` chunks amd return the + first and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last)""" + + @staticmethod + def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + @staticmethod + def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size) + + +class VocabParallelGPTLMHead1D(ParallelLayer): + """ + Language model head that shares the same parameters with the embedding matrix. + """ + + def __init__(self, embed=None, vocab_size=None, dtype=None, embed_dim=None): + super().__init__() + if embed is not None: + self.head = embed + else: + self.head = VocabParallelEmbedding1D(vocab_size, embed_dim, dtype=dtype) + + def forward(self, x: Tensor) -> Tensor: + x = reduce_grad(x, ParallelMode.PARALLEL_1D) + x = F.linear(x, self.head.weight) + return x + + +################################### + + +class HiddenParallelEmbedding(torch.nn.Module): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__( + self, + hidden_size, + vocab_size, + max_sequence_length, + embedding_dropout_prob, + dtype=torch.float, + padding_idx: int = 0, + num_tokentypes=0, + ): + super(HiddenParallelEmbedding, self).__init__() + + self.hidden_size = hidden_size + self.num_tokentypes = num_tokentypes + + # Word embeddings (parallel). + self.word_embeddings = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx) + self._word_embeddings_key = 'word_embeddings' + + # Position embedding (serial). + self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size) + self._position_embeddings_key = 'position_embeddings' + # Initialize the position embeddings. + # self.init_method(self.position_embeddings.weight) + + # Token type embedding. + # Add this as an optional field that can be added through + # method call so we can load a pretrain model without + # token types and add them as needed. + self._tokentype_embeddings_key = 'tokentype_embeddings' + if self.num_tokentypes > 0: + self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size) + # Initialize the token-type embeddings. + # self.init_method(self.tokentype_embeddings.weight) + else: + self.tokentype_embeddings = None + + # Embeddings dropout + self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) + + def zero_parameters(self): + """Zero out all parameters in embedding.""" + self.word_embeddings.weight.data.fill_(0) + self.word_embeddings.weight.shared = True + self.position_embeddings.weight.data.fill_(0) + self.position_embeddings.weight.shared = True + if self.num_tokentypes > 0: + self.tokentype_embeddings.weight.data.fill_(0) + self.tokentype_embeddings.weight.shared = True + + def add_tokentype_embeddings(self, num_tokentypes): + """Add token-type embedding. This function is provided so we can add + token-type embeddings in case the pretrained model does not have it. + This allows us to load the model normally and then add this embedding. + """ + if self.tokentype_embeddings is not None: + raise Exception('tokentype embeddings is already initialized') + if torch.distributed.get_rank() == 0: + print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True) + self.num_tokentypes = num_tokentypes + self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) + # Initialize the token-type embeddings. + # self.init_method(self.tokentype_embeddings.weight) + + def forward(self, input_ids, position_ids=None, tokentype_ids=None): + if input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + words_embeddings = self.word_embeddings(input_ids) + + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + if position_ids is None: + position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device()) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + position_embeddings = self.position_embeddings(position_ids) + + embeddings = words_embeddings + position_embeddings + + # Dropout. + with seed(ParallelMode.TENSOR): + embeddings = self.embedding_dropout(embeddings) + return embeddings + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + state_dict_[self._word_embeddings_key] \ + = self.word_embeddings.state_dict(destination, prefix, keep_vars) + state_dict_[self._position_embeddings_key] \ + = self.position_embeddings.state_dict( + destination, prefix, keep_vars) + if self.num_tokentypes > 0: + state_dict_[self._tokentype_embeddings_key] \ + = self.tokentype_embeddings.state_dict( + destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Word embedding. + if self._word_embeddings_key in state_dict: + state_dict_ = state_dict[self._word_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'word_embeddings' in key: + state_dict_[key.split('word_embeddings.')[1]] \ + = state_dict[key] + self.word_embeddings.load_state_dict(state_dict_, strict=strict) + + # Position embedding. + if self._position_embeddings_key in state_dict: + state_dict_ = state_dict[self._position_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'position_embeddings' in key: + state_dict_[key.split('position_embeddings.')[1]] \ + = state_dict[key] + self.position_embeddings.load_state_dict(state_dict_, strict=strict) + + # Tokentype embedding. + if self.num_tokentypes > 0: + state_dict_ = {} + if self._tokentype_embeddings_key in state_dict: + state_dict_ = state_dict[self._tokentype_embeddings_key] + else: + # for backward compatibility. + for key in state_dict.keys(): + if 'tokentype_embeddings' in key: + state_dict_[key.split('tokentype_embeddings.')[1]] \ + = state_dict[key] + if len(state_dict_.keys()) > 0: + self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict) + else: + print('***WARNING*** expected tokentype embeddings in the ' + 'checkpoint but could not find it', + flush=True) + + +class HiddenParallelEmbedding1D(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + init_method: method to initialize weights. + """ + + def __init__(self, num_embeddings, embedding_dim, dtype=torch.float, padding_idx: int = None, init_method=None): + super(HiddenParallelEmbedding1D, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size) + # Set the details for compatibility. + self.padding_idx = padding_idx + self.max_norm = None + self.norm_type = 2. + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + + # Allocate weights and initialize. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs)) + init.uniform_(self.weight, -1, 1) + + def forward(self, input_): + + # Get the embeddings. + output_parallel = F.embedding(input_, self.weight, self.padding_idx, self.max_norm, self.norm_type, + self.scale_grad_by_freq, self.sparse) + + # Reduce across all the model parallel GPUs. + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + return output + + +@LAYERS.register_module +class HiddenParallelGPTLMHead1D(ParallelLayer): + """ + Language model head that shares the same parameters with the embedding matrix. + """ + + def __init__( + self, + embed=None, + embed_dim=None, + vocab_size=None, + dtype=None, + ): + super().__init__() + if embed is not None: + self.head = embed + self.synced_embed = True + else: + # self.embedding = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx) + # (hidden_size/q, vocab_size) + self.synced_embed = False + self.head = Linear1D_Row(in_features=embed_dim, + out_features=vocab_size, + bias=False, + dtype=dtype, + parallel_input=False) + + def forward(self, x: Tensor) -> Tensor: + if self.synced_embed: + x = F.linear(x, self.head.weight) + else: + x = self.head(x) + + return x diff --git a/examples/language/gpt/titans/model/gpt1d.py b/examples/language/gpt/titans/model/gpt1d.py new file mode 100644 index 0000000000000000000000000000000000000000..2edd03606b7da3c7ba5418ec2ace2f1841576670 --- /dev/null +++ b/examples/language/gpt/titans/model/gpt1d.py @@ -0,0 +1,349 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math + +import torch +from torch import Tensor +from torch import nn as nn + +from colossalai import kernel +from colossalai import nn as col_nn +from colossalai.core import global_context as gpc +from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType +from colossalai.nn.layer import Linear1D_Col, Linear1D_Row +from colossalai.nn.layer.base_layer import ParallelLayer +from colossalai.nn.layer.utils import ACT2FN, divide +from colossalai.utils import checkpoint +from colossalai.utils.activation_checkpoint import checkpoint + +__all__ = [ + 'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D' +] + + +class GPTMLP1D(ParallelLayer): + + def __init__( + self, + in_features: int, + mlp_ratio: int, + act_func: str = 'gelu', + dropout_prob: float = 0., + dtype=None, + checkpoint: bool = False, + skip_bias_add: bool = False, + ): + super().__init__() + + self.in_features = in_features + self.mlp_ratio = mlp_ratio + self.checkpoint = checkpoint + self.skip_bias_add = skip_bias_add + + self.act = ACT2FN[act_func] + skip_dense_1_add_bias = False + + # Project to mlp_ratio * h. + self.dense_1 = Linear1D_Col( + self.in_features, + int(self.mlp_ratio * self.in_features), + dtype=dtype, + gather_output=False, + skip_bias_add=skip_dense_1_add_bias, + ) + + # Project back to h. + self.dense_2 = Linear1D_Row( + int(self.mlp_ratio * self.in_features), + self.in_features, + dtype=dtype, + parallel_input=True, + ) + + self.dropout = col_nn.Dropout(dropout_prob) + + def _forward(self, hidden_states: Tensor) -> Tensor: + intermediate_output = self.dense_1(hidden_states) + intermediate_output = self.act(intermediate_output) + + output = self.dense_2(intermediate_output) + output = self.dropout(output) + return output + + def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor: + return checkpoint(self._forward, False, hidden_states) + + def forward(self, hidden_states: Tensor) -> Tensor: + if self.checkpoint: + return self._checkpoint_forward(hidden_states) + else: + return self._forward(hidden_states) + + +class GenericGPTSelfAttention1D(ParallelLayer): + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + attention_dropout_prob: float, + hidden_dropout_prob: float, + dtype=None, + checkpoint: bool = False, + max_position_embeddings=1024, + ): + super().__init__() + self.hidden_size = hidden_size + self.attention_head_size = divide(hidden_size, num_attention_heads) + self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size) + self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size) + self.checkpoint = checkpoint + self.query_key_value = Linear1D_Col( + hidden_size, + 3 * hidden_size, + dtype=dtype, + ) + self.attention_dropout = col_nn.Dropout(attention_dropout_prob) + self.dense = Linear1D_Row( + hidden_size, + hidden_size, + dtype=dtype, + parallel_input=True, + ) + self.dropout = col_nn.Dropout(hidden_dropout_prob) + + def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer): + raise NotImplementedError + + def _forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor: + query_key_value = self.query_key_value(hidden_states) + new_qkv_shape = query_key_value.shape[:-1] + \ + (self.num_attention_heads_per_partition, 3 * self.attention_head_size) + query_key_value = query_key_value.view(new_qkv_shape) + query_key_value = query_key_value.permute((0, 2, 1, 3)) + query_layer, key_layer, value_layer = torch.chunk(query_key_value, 3, dim=-1) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = self.softmax_forward(attention_scores, attention_mask, query_layer, key_layer) + + attention_scores = attention_scores.type(value_layer.dtype) + + attention_probs = self.attention_dropout(attention_scores) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.transpose(1, 2) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(new_context_layer_shape) + output = self.dense(context_layer) + output = self.dropout(output) + + return output + + def _checkpoint_forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor: + return checkpoint(self._forward, False, hidden_states, attention_mask) + + def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor: + if self.checkpoint: + return self._checkpoint_forward(hidden_states, attention_mask) + else: + return self._forward(hidden_states, attention_mask) + + +class GPTSelfAttention1D(GenericGPTSelfAttention1D): + + def __init__(self, + hidden_size: int, + num_attention_heads: int, + attention_dropout_prob: float, + hidden_dropout_prob: float, + dtype=None, + checkpoint: bool = False, + max_position_embeddings=1024): + super().__init__(hidden_size, + num_attention_heads, + attention_dropout_prob, + hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings) + self.softmax = nn.Softmax(dim=-1) + max_positions = max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), + dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) + + def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer): + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # causal mask + query_length, key_length = query_layer.size(-2), key_layer.size(-2) + causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool() + attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores)) + if attention_mask is not None: + # Apply the attention mask + attention_scores = attention_scores + attention_mask + attention_scores = self.softmax(attention_scores) + return attention_scores + + +class FusedGPTSelfAttention1D(GenericGPTSelfAttention1D): + + def __init__(self, + hidden_size: int, + num_attention_heads: int, + attention_dropout_prob: float, + hidden_dropout_prob: float, + dtype=None, + checkpoint: bool = False, + max_position_embeddings=1024): + super().__init__(hidden_size, + num_attention_heads, + attention_dropout_prob, + hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings) + self.softmax = kernel.FusedScaleMaskSoftmax(input_in_fp16=True, + input_in_bf16=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=True, + mask_func=None, + softmax_in_fp32=True, + scale=math.sqrt(self.attention_head_size)) + + def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer): + return self.softmax(attention_scores, attention_mask) + + +class GenericGPTTransformerLayer1D(ParallelLayer): + + def __init__(self, + hidden_size: int, + num_attention_heads: int, + act_func: str = 'gelu', + mlp_ratio: float = 4.0, + attention_dropout_prob: float = 0., + hidden_dropout_prob: float = 0., + dtype=None, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + attention=None, + layer_norm=None): + super().__init__() + self.checkpoint = checkpoint + self.dtype = dtype + self.norm1 = layer_norm(hidden_size, eps=layer_norm_epsilon) + self.apply_post_layer_norm = apply_post_layer_norm + self.attention = attention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout_prob=attention_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + dtype=dtype, + max_position_embeddings=max_position_embeddings, + checkpoint=False, + ) + + self.norm2 = layer_norm(hidden_size, eps=layer_norm_epsilon) + self.mlp = GPTMLP1D( + in_features=hidden_size, + dropout_prob=hidden_dropout_prob, + act_func=act_func, + mlp_ratio=mlp_ratio, + dtype=dtype, + checkpoint=False, + ) + + def _forward(self, hidden_states, attention_mask) -> Tensor: + if not self.apply_post_layer_norm: + residual = hidden_states + hidden_states = self.norm1(hidden_states) + if self.apply_post_layer_norm: + residual = hidden_states + attention_output = self.attention(hidden_states, attention_mask) + hidden_states = residual + attention_output + + if not self.apply_post_layer_norm: + residual = hidden_states + hidden_states = self.norm2(hidden_states) + if self.apply_post_layer_norm: + residual = hidden_states + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = residual + feed_forward_hidden_states + + output = (hidden_states, attention_mask) + return output + + def forward(self, hidden_states, attention_mask): + if self.checkpoint: + return checkpoint(self._forward, False, hidden_states, attention_mask) + else: + return self._forward(hidden_states, attention_mask) + + +class GPTTransformerLayer1D(GenericGPTTransformerLayer1D): + + def __init__(self, + hidden_size: int, + num_attention_heads: int, + act_func: str = 'gelu', + mlp_ratio: float = 4, + attention_dropout_prob: float = 0, + hidden_dropout_prob: float = 0, + dtype=None, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 0.00001, + apply_post_layer_norm: bool = False): + attention = GPTSelfAttention1D + layer_norm = nn.LayerNorm + super().__init__(hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attention_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm, + attention=attention, + layer_norm=layer_norm) + + +class FusedGPTTransformerLayer1D(GenericGPTTransformerLayer1D): + + def __init__(self, + hidden_size: int, + num_attention_heads: int, + act_func: str = 'gelu', + mlp_ratio: float = 4, + attention_dropout_prob: float = 0, + hidden_dropout_prob: float = 0, + dtype=None, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 0.00001, + apply_post_layer_norm: bool = False): + attention = FusedGPTSelfAttention1D + layer_norm = kernel.LayerNorm + super().__init__(hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attention_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm, + attention=attention, + layer_norm=layer_norm) diff --git a/examples/language/gpt/titans/model/pipeline_gpt1d.py b/examples/language/gpt/titans/model/pipeline_gpt1d.py new file mode 100644 index 0000000000000000000000000000000000000000..30180285bc70fefaec5610a4b2d284a946a25d2c --- /dev/null +++ b/examples/language/gpt/titans/model/pipeline_gpt1d.py @@ -0,0 +1,322 @@ +import inspect + +# import model_zoo.gpt.gpt as col_gpt +import titans.model.gpt.gpt as col_gpt +import torch +import torch.nn as nn + +from colossalai import kernel +from colossalai import nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper +from colossalai.pipeline.utils import partition_uniform + +from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabParallelEmbedding, VocabParallelGPTLMHead1D +from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D + +__all__ = [ + 'GPT2_small_pipeline_1D', + 'GPT2_exlarge_pipeline_1D', + 'GPT3_pipeline_1D', + 'GPT2_exlarge_pipeline_hybrid', + 'GPT2_small_pipeline_hybrid', + 'GPT3_pipeline_hybrid', +] + + +class GenericPipelineGPT(nn.Module): + + def __init__(self, embedding=None, blocks=None, norm=None, head=None) -> None: + super().__init__() + self.embedding = embedding + self.blocks = blocks + self.norm = norm + self.head = head + assert blocks is not None + if norm is not None or head is not None: + assert norm is not None and head is not None + + def forward(self, hidden_states=None, input_ids=None, attention_mask=None): + if self.embedding is not None: + hidden_states = self.embedding(input_ids=input_ids) + batch_size = hidden_states.shape[0] + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + for block in self.blocks: + hidden_states, attention_mask = block(hidden_states, attention_mask) + if self.norm is not None: + hidden_states = self.head(self.norm(hidden_states)) + return hidden_states + + +class PipelineGPT1D(GenericPipelineGPT): + + def __init__(self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + embed_drop_rate: float = 0., + act_func: str = 'gelu', + mlp_ratio: int = 4.0, + attn_drop_rate: float = 0., + drop_rate: float = 0., + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + first: bool = False, + last: bool = False, + embed_split_hidden=False): + embedding = None + norm = None + head = None + embed_cls = VocabParallelEmbedding + head_cls = VocabParallelGPTLMHead1D + if embed_split_hidden: + embed_cls = HiddenParallelEmbedding + head_cls = HiddenParallelGPTLMHead1D + if first: + embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype) + blocks = nn.ModuleList([ + GPTTransformerLayer1D(hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attn_drop_rate, + hidden_dropout_prob=drop_rate, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers) + ]) + if last: + norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype) + super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head) + + +class FusedPipelineGPT1D(GenericPipelineGPT): + + def __init__(self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + embed_drop_rate: float = 0., + act_func: str = 'gelu', + mlp_ratio: int = 4.0, + attn_drop_rate: float = 0., + drop_rate: float = 0., + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + first: bool = False, + last: bool = False, + embed_split_hidden=False): + embedding = None + norm = None + head = None + embed_cls = VocabParallelEmbedding + head_cls = VocabParallelGPTLMHead1D + if embed_split_hidden: + embed_cls = HiddenParallelEmbedding + head_cls = HiddenParallelGPTLMHead1D + if first: + embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype) + blocks = nn.ModuleList([ + FusedGPTTransformerLayer1D(hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attn_drop_rate, + hidden_dropout_prob=drop_rate, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers) + ]) + if last: + norm = kernel.LayerNorm(hidden_size, eps=layer_norm_epsilon) + head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype) + super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head) + + def forward(self, hidden_states=None, input_ids=None, attention_mask=None): + if self.embedding is not None: + hidden_states = self.embedding(input_ids=input_ids) + attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility + for block in self.blocks: + hidden_states, attention_mask = block(hidden_states, attention_mask) + if self.norm is not None: + hidden_states = self.head(self.norm(hidden_states)) + return hidden_states + + +class PipelineGPTHybrid(GenericPipelineGPT): + + def __init__(self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + embed_drop_rate: float = 0., + act_func: str = 'gelu', + mlp_ratio: int = 4, + attn_drop_rate: float = 0., + drop_rate: float = 0., + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + first: bool = False, + last: bool = False, + embed_split_hidden=False): + embedding = None + norm = None + head = None + if first: + embedding = col_gpt.GPTEmbedding(hidden_size, + vocab_size, + max_position_embeddings, + dropout=embed_drop_rate, + dtype=dtype) + blocks = nn.ModuleList([ + col_gpt.GPTBlock(hidden_size, + num_attention_heads, + mlp_ratio=mlp_ratio, + attention_dropout=attn_drop_rate, + dropout=drop_rate, + dtype=dtype, + checkpoint=checkpoint, + activation=nn.functional.gelu) for _ in range(num_layers) + ]) + if last: + norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + # head = col_gpt.GPTLMHead(vocab_size=vocab_size, + # hidden_size=hidden_size, + # dtype=dtype, + # bias=False) + head = col_nn.Classifier(hidden_size, vocab_size, dtype=dtype, bias=False) + super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head) + + +def _filter_kwargs(func, kwargs): + sig = inspect.signature(func) + return {k: v for k, v in kwargs.items() if k in sig.parameters} + + +def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs): + logger = get_dist_logger() + + if gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + else: + pipeline_size = 1 + pipeline_rank = 0 + rank = gpc.get_global_rank() + + if pipeline_size > 1: + wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1]) + else: + wrapper = None + parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] + models = [] + for start, end in parts: + kwargs['num_layers'] = end - start + kwargs['first'] = start == 0 + kwargs['last'] = end == num_layers + logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') + chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device) + + if wrapper is not None: + if start == 0: + wrapper.register_module(chunk.embedding.word_embeddings) + elif end == num_layers: + wrapper.register_module(chunk.head) + models.append(chunk) + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + + numel = 0 + for _, param in model.named_parameters(recurse=True): + numel += param.numel() + logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB') + return model + + +def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), fused=False, **kwargs): + model = FusedPipelineGPT1D if fused else PipelineGPT1D + return _build_generic_gpt_pipeline_1d(model, num_layers, num_chunks, device, **kwargs) + + +def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): + return _build_generic_gpt_pipeline_1d(PipelineGPTHybrid, num_layers, num_chunks, device, **kwargs) + + +def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False): + cfg = dict(hidden_size=768, + num_attention_heads=12, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden) + return _build_gpt_pipeline_1d(12, num_chunks, fused=fused, **cfg) + + +def GPT2_exlarge_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False): + cfg = dict(hidden_size=1600, + num_attention_heads=32, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden) + return _build_gpt_pipeline_1d(48, num_chunks, fused=fused, **cfg) + + +def GPT3_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False): + cfg = dict(hidden_size=12288, + num_attention_heads=96, + checkpoint=checkpoint, + max_position_embeddings=2048, + dtype=dtype, + embed_split_hidden=embed_split_hidden) + return _build_gpt_pipeline_1d(96, num_chunks, fused=fused, **cfg) + + +def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False): + cfg = dict(hidden_size=1600, + num_attention_heads=32, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden) + return _build_gpt_pipeline_hybrid(48, num_chunks, **cfg) + + +def GPT2_small_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False): + cfg = dict(hidden_size=768, + num_attention_heads=12, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden) + return _build_gpt_pipeline_hybrid(12, num_chunks, **cfg) + + +def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False): + cfg = dict(hidden_size=12288, + num_attention_heads=96, + checkpoint=checkpoint, + max_position_embeddings=2048, + dtype=dtype, + embed_split_hidden=embed_split_hidden) + return _build_gpt_pipeline_hybrid(96, num_chunks, **cfg) diff --git a/examples/language/gpt/titans/requirements.txt b/examples/language/gpt/titans/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..64ff7a4abcd84a5979c199591845597129535c4e --- /dev/null +++ b/examples/language/gpt/titans/requirements.txt @@ -0,0 +1,4 @@ +torch==1.12.1 +titans==0.0.7 +colossalai==0.2.0+torch1.12cu11.3 +-f https://release.colossalai.org diff --git a/examples/language/gpt/titans/run.sh b/examples/language/gpt/titans/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..a1a7fc737db0ac8820b17313f8d7e5b40123d3d8 --- /dev/null +++ b/examples/language/gpt/titans/run.sh @@ -0,0 +1,3 @@ +export DATA=/data/scratch/gpt_data/small-gpt-dataset.json +DUMMY_DATA=--use_dummy_dataset +colossalai run --nproc_per_node=2 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch $DUMMY_DATA diff --git a/examples/language/gpt/titans/test_ci.sh b/examples/language/gpt/titans/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..7cb24c1a4082c685a6d549f7a06574b1971aee39 --- /dev/null +++ b/examples/language/gpt/titans/test_ci.sh @@ -0,0 +1 @@ +colossalai run --nproc_per_node=4 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch --use_dummy_dataset diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..66225d6c80447a147c70f2937ca2d1ad9e9aa5c4 --- /dev/null +++ b/examples/language/gpt/titans/train_gpt.py @@ -0,0 +1,113 @@ +import contextlib +import os + +import torch +import torch.nn as nn +from dataset.webtext import WebtextDataset +from titans.model.gpt import GPTLMLoss + +import colossalai +import colossalai.utils as utils +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn import LinearWarmupLR +from colossalai.trainer import Trainer, hooks +from colossalai.utils import colo_set_process_memory_fraction, is_using_pp +from colossalai.utils.timer import MultiTimer +from colossalai.zero.init_ctx import ZeroInitContext + + +def calc_local_model_size(model: torch.nn.Module): + numel_per_device = 0 + for p in model.parameters(): + numel_per_device += p.numel() + return numel_per_device + + +VOCAB_SIZE = 50257 + + +def main(): + parser = colossalai.get_default_parser() + parser.add_argument('--from_torch', default=False, action='store_true') + parser.add_argument('--use_dummy_dataset', default=False, action='store_true') + args = parser.parse_args() + disable_existing_loggers() + if args.from_torch: + colossalai.launch_from_torch(config=args.config) + else: + colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42) + logger = get_dist_logger() + + data_path = None if args.use_dummy_dataset else os.environ['DATA'] + logger.info(f'Build data loader from path {data_path}', ranks=[0]) + + train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN) + train_dataloader = utils.get_dataloader(train_ds, + seed=42, + batch_size=gpc.config.BATCH_SIZE, + pin_memory=True, + shuffle=True, + drop_last=True) + + logger.info('Build model', ranks=[0]) + use_pipeline = is_using_pp() + use_interleaved = hasattr(gpc.config.model, 'num_chunks') + use_zero3 = hasattr(gpc.config, 'zero') + ctx = contextlib.nullcontext() + if use_zero3: + ctx = ZeroInitContext(target_device=torch.cuda.current_device(), + shard_strategy=gpc.config.zero.model_config.shard_strategy, + shard_param=True) + with ctx: + model = gpc.config.model.pop('type')(**gpc.config.model) + if use_pipeline and use_interleaved and not isinstance(model, nn.ModuleList): + model = nn.ModuleList([model]) + + if use_zero3: + numel = ctx.model_numel_tensor.item() + else: + numel = calc_local_model_size(model) + + tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LEN \ + * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4) + + criterion = getattr(gpc.config, 'loss_fn', None) + if criterion is not None: + criterion = criterion.type() + else: + criterion = GPTLMLoss() + logger.info('Build optimizer', ranks=[0]) + optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer) + lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5) + engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model, + optimizer, + criterion, + train_dataloader=train_dataloader, + lr_scheduler=lr_scheduler) + global_batch_size = gpc.config.BATCH_SIZE * \ + gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1) + logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0]) + timier = MultiTimer() + trainer = Trainer(engine=engine, logger=logger, timer=timier) + hook_list = [ + hooks.LossHook(), + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True), + hooks.LogMetricByEpochHook(logger), + hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop), + hooks.LogMetricByStepHook(), + hooks.LogMemoryByEpochHook(logger), + # hooks.LogMemoryByEpochHook(logger), + # hooks.LogTimingByEpochHook(timer, logger), + ] + trainer.fit(train_dataloader=train_dataloader, + epochs=gpc.config.NUM_EPOCHS, + test_interval=1, + hooks=hook_list, + display_progress=True, + return_output_label=False) + + +if __name__ == '__main__': + main() diff --git a/examples/language/opt/README.md b/examples/language/opt/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c2fd254571c7fb7b9df3de8539fae32826a70e3f --- /dev/null +++ b/examples/language/opt/README.md @@ -0,0 +1,33 @@ + + +## OPT +Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. + +The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. + +We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before +the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). + +## Our Modifications +We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP. + +## Quick Start +You can launch training by using the following bash script + +```bash +bash ./run_gemini.sh +``` diff --git a/examples/language/opt/benchmark.sh b/examples/language/opt/benchmark.sh new file mode 100644 index 0000000000000000000000000000000000000000..0d04b5e9b33cad5bef7037a4a9f700290c0c338a --- /dev/null +++ b/examples/language/opt/benchmark.sh @@ -0,0 +1,21 @@ +export BS=16 +export MEMCAP=0 +export MODEL="6.7b" +export GPUNUM=1 + +for MODEL in "6.7b" "13b" "1.3b" +do +for GPUNUM in 8 1 +do +for BS in 16 24 32 8 +do +for MEMCAP in 0 40 +do +pkill -9 torchrun +pkill -9 python + +env BS=$BS MEM_CAP=$MEMCAP MODEL=$MODEL GPUNUM=$GPUNUM bash ./run_gemini.sh +done +done +done +done diff --git a/examples/language/opt/requirements.txt b/examples/language/opt/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..137a69e80498223cd7581a62e2e27320b77682a0 --- /dev/null +++ b/examples/language/opt/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/examples/language/opt/run_gemini.sh b/examples/language/opt/run_gemini.sh new file mode 100644 index 0000000000000000000000000000000000000000..73f231292a132ec0a2837efe86dd6f5bc3eb81ba --- /dev/null +++ b/examples/language/opt/run_gemini.sh @@ -0,0 +1,28 @@ +set -x +export BS=${BS:-16} +export MEMCAP=${MEMCAP:-0} +# Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`. For `175b` +export MODEL=${MODEL:-"125m"} +export GPUNUM=${GPUNUM:-1} +export USE_SHARD_INIT=${USE_SHARD_INIT:-"false"} + +# make directory for logs +mkdir -p ./logs + +if [ ${USE_SHARD_INIT} = "true" ]; then + USE_SHARD_INIT="--shardinit" +else + USE_SHARD_INIT="" +fi + +export MODLE_PATH="facebook/opt-${MODEL}" + +# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 +torchrun \ + --nproc_per_node ${GPUNUM} \ + --master_port 19198 \ + train_gemini_opt.py \ + --mem_cap ${MEMCAP} \ + --model_name_or_path ${MODLE_PATH} \ + ${USE_SHARD_INIT} \ + --batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log diff --git a/examples/language/opt/test_ci.sh b/examples/language/opt/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..317f602cda3c5a63b2d6130fff8a363cc8f613e2 --- /dev/null +++ b/examples/language/opt/test_ci.sh @@ -0,0 +1,4 @@ +for GPUNUM in 2 1 +do +env BS=2 MODEL="125m" GPUNUM=$GPUNUM bash ./run_gemini.sh +done diff --git a/examples/language/opt/train_gemini_opt.py b/examples/language/opt/train_gemini_opt.py new file mode 100755 index 0000000000000000000000000000000000000000..3614b689de26fbf8fb89a7c554bd9de4802da582 --- /dev/null +++ b/examples/language/opt/train_gemini_opt.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. + +import time +from functools import partial + +import datasets +import torch +import torch.distributed as dist +import transformers +from transformers import CONFIG_MAPPING, MODEL_MAPPING, AutoConfig, OPTForCausalLM +from transformers.utils.versions import require_version + +import colossalai +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.tensor import ProcessGroup, ShardSpec +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP + + +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def get_time_stamp(): + torch.cuda.synchronize() + return time.time() + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=True, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--batch_size", + type=int, + default=8, + help="Batch size (per dp group) for the training dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument( + "--max_train_steps", + type=int, + default=20, + help="Total number of training steps to perform.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--shardinit", + action="store_true", + help="Initialize the model with tensor parallel", + ) + parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") + parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") + args = parser.parse_args() + + return args + + +def colo_memory_cap(size_in_GB): + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) + if size_in_GB * (1024**3) < cuda_capacity: + colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) + print("Using {} GB of GPU memory".format(size_in_GB)) + + +def main(): + args = parse_args() + disable_existing_loggers() + colossalai.launch_from_torch({}) + logger = get_dist_logger() + is_main_process = dist.get_rank() == 0 + + if is_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + if args.mem_cap > 0: + colo_memory_cap(args.mem_cap) + + # If passed along, set the training seed now. + if args.seed is not None: + torch.mannul_seed(args.seed) + logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}") + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + if args.config_name: + config = AutoConfig.from_pretrained(args.config_name) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained(args.model_name_or_path) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + logger.info("Model config has been created", ranks=[0]) + + if args.init_in_cpu: + init_dev = torch.device('cpu') + else: + init_dev = get_current_device() + + # shard init parameters + if args.shardinit: + logger.info("Sharding initialization !", ranks=[0]) + else: + logger.info("Skipping sharding initialization", ranks=[0]) + + world_size = torch.distributed.get_world_size() + shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None + default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None + + # build model + if args.model_name_or_path is None: + logger.info("Train a new model from scratch", ranks=[0]) + with ColoInitContext(device=init_dev, + dtype=torch.half, + default_dist_spec=default_dist_spec, + default_pg=shard_pg): + model = OPTForCausalLM(config) + else: + logger.info("Finetune a pre-trained model", ranks=[0]) + with ColoInitContext(device=init_dev, + dtype=torch.half, + default_dist_spec=default_dist_spec, + default_pg=shard_pg): + model = OPTForCausalLM.from_pretrained(args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + local_files_only=False) + + # enable gradient checkpointing + model.gradient_checkpointing_enable() + + numel = sum([p.numel() for p in model.parameters()]) + PLACEMENT_POLICY = 'cpu' + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=PLACEMENT_POLICY, + pin_memory=True, + strict_ddp_mode=args.shardinit) + optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0) + + SEQ_LEN = 1024 + VOCAB_SIZE = 50257 + + get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN) + + model.train() + for step in range(args.max_train_steps): + st_time = time.time() + input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE) + + outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False) + loss = outputs['loss'] + optimizer.backward(loss) + + optimizer.step() + optimizer.zero_grad() + torch.cuda.synchronize() + step_time = time.time() - st_time + step_tflops = get_tflops_func(step_time) + + logger.info("step {} finished, Tflops {}".format(step, step_tflops), ranks=[0]) + + logger.info("Training finished", ranks=[0]) + + +if __name__ == "__main__": + main() diff --git a/examples/language/palm/README.md b/examples/language/palm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..486bf240f89c1abbe442512660ee9f1aec6c1b9d --- /dev/null +++ b/examples/language/palm/README.md @@ -0,0 +1,64 @@ + + +## PaLM - Pytorch + +Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways, in less than 200 lines of code. + +This model is pretty much SOTA on everything language. + +It obviously will not scale, but it is just for educational purposes. To elucidate the public how simple it all really is. + +## Install +```bash +$ pip install PaLM-pytorch +``` + +## Usage + +```python +import torch +from palm_pytorch import PaLM + +palm = PaLM( + num_tokens = 20000, + dim = 512, + depth = 12, + heads = 8, + dim_head = 64, +) + +tokens = torch.randint(0, 20000, (1, 2048)) +logits = palm(tokens) # (1, 2048, 20000) +``` + +The PaLM 540B in the paper would be + +```python +palm = PaLM( + num_tokens = 256000, + dim = 18432, + depth = 118, + heads = 48, + dim_head = 256 +) +``` + +## Test on Enwik8 + +```bash +$ python train.py +``` + +## Todo + +- [ ] offer a Triton optimized version of PaLM, bringing in https://github.com/lucidrains/triton-transformer + +## Citations + +```bibtex +@article{chowdhery2022PaLM, + title = {PaLM: Scaling Language Modeling with Pathways}, + author = {Chowdhery, Aakanksha et al}, + year = {2022} +} +``` diff --git a/examples/language/palm/data/README.md b/examples/language/palm/data/README.md new file mode 100644 index 0000000000000000000000000000000000000000..56433b4dc6cb24ff245282e6e75985ad720eae6c --- /dev/null +++ b/examples/language/palm/data/README.md @@ -0,0 +1,3 @@ +# Data source + +The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/ diff --git a/examples/language/palm/palm_pytorch/__init__.py b/examples/language/palm/palm_pytorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dab49645a643a5a0680d9c8e6483458ce4b387a8 --- /dev/null +++ b/examples/language/palm/palm_pytorch/__init__.py @@ -0,0 +1 @@ +from palm_pytorch.palm_pytorch import PaLM diff --git a/examples/language/palm/palm_pytorch/autoregressive_wrapper.py b/examples/language/palm/palm_pytorch/autoregressive_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..dc4f3d856fecaa09581cd6771bc5c9460c55865b --- /dev/null +++ b/examples/language/palm/palm_pytorch/autoregressive_wrapper.py @@ -0,0 +1,77 @@ +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +# helper function + + +def exists(val): + return val is not None + + +def eval_decorator(fn): + + def inner(model, *args, **kwargs): + was_training = model.training + model.eval() + out = fn(model, *args, **kwargs) + model.train(was_training) + return out + + return inner + + +# top k filtering + + +def top_k(logits, thres=0.9): + k = int((1 - thres) * logits.shape[-1]) + val, ind = torch.topk(logits, k) + probs = torch.full_like(logits, float("-inf")) + probs.scatter_(1, ind, val) + return probs + + +class AutoregressiveWrapper(nn.Module): + + def __init__(self, net, max_seq_len=2048, pad_value=0): + super().__init__() + self.max_seq_len = max_seq_len + self.pad_value = pad_value + self.net = net + + @torch.no_grad() + @eval_decorator + def generate(self, start_tokens, seq_len, eos_token=None, temperature=1.0, filter_thres=0.9, **kwargs): + b, t, device = *start_tokens.shape, start_tokens.device + + out = start_tokens + + for _ in range(seq_len): + logits = self.net(out, **kwargs)[:, -1, :] + + filtered_logits = top_k(logits, thres=filter_thres) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + sample = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + if exists(eos_token): + is_eos_token = out == eos_token + + if is_eos_token.any(dim=-1).all(): + # mask out everything after the eos tokens + shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) + mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1 + out = out.masked_fill(mask, self.pad_value) + break + + out = out[:, t:] + return out + + def forward(self, x, **kwargs): + x_inp, x_labels = x[:, :-1], x[:, 1:] + logits = self.net(x_inp, **kwargs) + return F.cross_entropy(rearrange(logits, "b c n -> b n c"), x_labels) diff --git a/examples/language/palm/palm_pytorch/palm_pytorch.py b/examples/language/palm/palm_pytorch/palm_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..c37974711e11b7046e474f27fbc138b246e63d9f --- /dev/null +++ b/examples/language/palm/palm_pytorch/palm_pytorch.py @@ -0,0 +1,207 @@ +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import einsum, matmul, nn + +# normalization +# they use layernorm without bias, something that pytorch does not offer + + +class LayerNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.gamma = nn.Parameter(torch.ones(dim)) + self.register_buffer("beta", torch.zeros(dim)) + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) + + +# parallel with residual +# discovered by Wang et al + EleutherAI from GPT-J fame + + +class ParallelResidual(nn.Module): + + def __init__(self, *fns): + super().__init__() + self.fns = nn.ModuleList(fns) + + def forward(self, x): + return x + sum([fn(x) for fn in self.fns]) + + +# rotary positional embedding +# https://arxiv.org/abs/2104.09864 + + +class RotaryEmbedding(nn.Module): + + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, max_seq_len, *, device): + seq = torch.arange(max_seq_len, device=device) + #freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq) + #freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) + i, j = len(seq.type_as(self.inv_freq)), len(self.inv_freq) + freqs = matmul(seq.type_as(self.inv_freq).reshape(i, 1), self.inv_freq.reshape(1, j)) + return torch.cat((freqs, freqs), dim=-1) + + +def rotate_half(x): + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(pos, t): + return (t * pos.cos()) + (rotate_half(t) * pos.sin()) + + +# feedforward +# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU +# https://arxiv.org/abs/2002.05202 + + +class SwiGLU(nn.Module): + + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + LayerNorm(dim), + nn.Linear(dim, inner_dim * 2, bias=False), + SwiGLU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +# attention +class Attention(nn.Module): + + def __init__(self, dim, dim_head=64, heads=8): + super().__init__() + inner_dim = dim_head * heads + self.norm = LayerNorm(dim) + self.heads = heads + self.scale = dim_head**-0.5 + self.rotary_emb = RotaryEmbedding(dim_head) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, dim_head * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # for caching causal mask and rotary embeddings + + self.register_buffer("mask", None, persistent=False) + self.register_buffer("pos_emb", None, persistent=False) + + def get_mask(self, n, device): + if self.mask is not None and self.mask.shape[-1] >= n: + return self.mask[:n, :n] + + mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) + self.register_buffer("mask", mask, persistent=False) + return mask + + def get_rotary_embedding(self, n, device): + if self.pos_emb is not None and self.pos_emb.shape[-2] >= n: + return self.pos_emb[:n] + + pos_emb = self.rotary_emb(n, device=device) + self.register_buffer("position", pos_emb, persistent=False) + return pos_emb + + def forward(self, x): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + n, device, h = x.shape[1], x.device, self.heads + + # pre layernorm + + x = self.norm(x) + + # queries, keys, values + + q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)) + + # split heads + # they use multi-query single-key-value attention, yet another Noam Shazeer paper + # they found no performance loss past a certain scale, and more efficient decoding obviously + # https://arxiv.org/abs/1911.02150 + + q = rearrange(q, "b n (h d) -> b h n d", h=h) + + # rotary embeddings + + positions = self.get_rotary_embedding(n, device) + q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k)) + + # scale + + q = q * self.scale + + b, h, i, d, j = q.size(0), q.size(1), q.size(2), q.size(3), k.size(1) + + # similarity + + #sim = einsum("b h i d, b j d -> b h i j", q, k) + sim = matmul(q.reshape(b, h * i, d), k.transpose(1, 2)) + sim = sim.reshape(b, h, i, j) + + # causal mask + + causal_mask = self.get_mask(n, device) + sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + + # attention + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + b_, h_, i_, j_, d_ = attn.size(0), attn.size(1), attn.size(2), attn.size(3), v.size(2) + + # aggregate values + + #out = einsum("b h i j, b j d -> b h i d", attn, v) + out = matmul(attn.reshape(b_, h_ * i_, j_), v) + out = out.reshape(b_, h_, i_, d_) + + # merge heads + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +# transformer + + +def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4): + net = nn.Sequential( + nn.Embedding(num_tokens, dim), *[ + ParallelResidual( + Attention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ) for _ in range(depth) + ], LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False)) + + # they used embedding weight tied projection out to logits, not common, but works + net[-1].weight = net[0].weight + + nn.init.normal_(net[0].weight, std=0.02) + return net diff --git a/examples/language/palm/requirements.txt b/examples/language/palm/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..137a69e80498223cd7581a62e2e27320b77682a0 --- /dev/null +++ b/examples/language/palm/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/examples/language/palm/run.sh b/examples/language/palm/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..7a533509e009cfb7fe8d1c88d63b0457a2c21a65 --- /dev/null +++ b/examples/language/palm/run.sh @@ -0,0 +1,11 @@ +# distplan in ["colossalai", "pytorch"] +export DISTPAN="colossalai" + +# The following options only valid when DISTPAN="colossalai" +export TPDEGREE=1 +export GPUNUM=1 +export PLACEMENT='cpu' +export USE_SHARD_INIT=False +export BATCH_SIZE=4 + +env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log diff --git a/examples/language/palm/test_ci.sh b/examples/language/palm/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..f21095578077eda89c08eeeebba0013327e26ea9 --- /dev/null +++ b/examples/language/palm/test_ci.sh @@ -0,0 +1,9 @@ +$(cd `dirname $0`;pwd) + +for BATCH_SIZE in 2 +do +for GPUNUM in 1 4 +do +env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --dummy_data=True --batch_size=${BATCH_SIZE} 2>&1 | tee run.log +done +done diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py new file mode 100644 index 0000000000000000000000000000000000000000..7923e4fc855d17ef8fbd7411c0d3c3c3bde51d67 --- /dev/null +++ b/examples/language/palm/train.py @@ -0,0 +1,308 @@ +import gzip +import random +from functools import partial +from time import time + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import tqdm +from packaging import version +from palm_pytorch import PaLM +from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper +from torch.utils.data import DataLoader, Dataset + +import colossalai +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.utils import MultiTimer, get_current_device +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP + +# constants + +NUM_BATCHES = int(10) +WARMUP_BATCHES = 1 +GRADIENT_ACCUMULATE_EVERY = 1 +LEARNING_RATE = 2e-4 +VALIDATE_EVERY = 100 +GENERATE_EVERY = 500 +GENERATE_LENGTH = 512 +SEQ_LEN = 1024 + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument( + "--distplan", + type=str, + default='colossalai', + help="The distributed plan [colossalai, pytorch].", + ) + parser.add_argument( + "--tp_degree", + type=int, + default=1, + help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--placement", + type=str, + default='cpu', + help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--shardinit", + type=bool, + default=False, + help= + "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=8, + help="batch size per DP group of training.", + ) + parser.add_argument( + "--dummy_data", + type=bool, + default=False, + help="use dummy dataset.", + ) + args = parser.parse_args() + return args + + +# helpers +def cycle(loader): + while True: + for data in loader: + yield data + + +def decode_token(token): + return str(chr(max(32, token))) + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + + +def decode_tokens(tokens): + return "".join(list(map(decode_token, tokens))) + + +def get_model_size(model: nn.Module): + total_numel = 0 + for module in model.modules(): + for p in module.parameters(recurse=False): + total_numel += p.numel() + return total_numel + + +# Gemini + ZeRO DDP +def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): + cai_version = colossalai.__version__ + if version.parse(cai_version) > version.parse("0.1.10"): + from colossalai.nn.parallel import GeminiDDP + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=placememt_policy, + pin_memory=True, + search_range_mb=32) + elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): + from colossalai.gemini import ChunkManager, GeminiManager + chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) + gemini_manager = GeminiManager(placememt_policy, chunk_manager) + chunk_manager = ChunkManager(chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(placememt_policy)) + model = ZeroDDP(model, gemini_manager) + else: + raise NotImplemented(f"CAI version {cai_version} is not supported") + return model + + +# Parameter Sharding Strategies for Tensor Parallelism +def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(-1, param, pg) + + +# Tensor Parallel +def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): + """tensor_parallelize + Sharding the Model Parameters. + Args: + model (torch.nn.Module): a torch module to be sharded + """ + for mn, module in model.named_modules(): + for pn, param in module.named_parameters(recurse=False): + if hasattr(param, 'visited'): + continue + param.set_dist_spec(ReplicaSpec()) + if 'net.0' in mn: + split_param_col_tp1d(param, pg) # colmn slice + elif 'to_q' in mn: + split_param_col_tp1d(param, pg) # colmn slice + elif 'to_kv' in mn: + split_param_row_tp1d(param, pg) # row slice + elif 'to_out' in mn: + split_param_row_tp1d(param, pg) # row slice + elif '1.1' in mn: + split_param_col_tp1d(param, pg) # colmn slice + elif '1.2' in mn: + split_param_row_tp1d(param, pg) # row slice + else: + param.set_dist_spec(ReplicaSpec()) + param.visited = True + + +args = parse_args() +if args.distplan not in ["colossalai", "pytorch"]: + raise TypeError(f"{args.distplan} is error") +disable_existing_loggers() +colossalai.launch_from_torch(config={}) +logger = get_dist_logger() + + +def generate_dataset(dummy_data: bool = False): + if not dummy_data: + with gzip.open("./data/enwik8.gz") as file: + X = np.fromstring(file.read(int(95e6)), dtype=np.uint8) + trX, vaX = np.split(X, [int(90e6)]) + data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX) + # print(f"data_train {data_train.shape} {data_train.dtype} {max(data_train)} {min(data_train)}") + # print(f"data_val {data_val.shape} {data_val.dtype} {max(data_val)} {min(data_val)}") + return data_train, data_val + else: + return torch.randint(0, 100, (90000000,)), torch.randint(0, 100, (5000000,)) + + +data_train, data_val = generate_dataset(args.dummy_data) + +print("generate dataset ready!") + + +class TextSamplerDataset(Dataset): + + def __init__(self, data, seq_len): + super().__init__() + self.data = data + self.seq_len = seq_len + + def __getitem__(self, index): + rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) + full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long() + return full_seq.cuda() + + def __len__(self): + return self.data.size(0) // self.seq_len + + +train_dataset = TextSamplerDataset(data_train, SEQ_LEN) +val_dataset = TextSamplerDataset(data_val, SEQ_LEN) +train_loader = cycle(DataLoader(train_dataset, batch_size=args.batch_size)) +val_loader = cycle(DataLoader(val_dataset, batch_size=args.batch_size)) + +if args.distplan == "colossalai": + # instantiate GPT-like decoder model + + default_pg = ProcessGroup(tp_degree=args.tp_degree) + default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None + ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg) + + with ctx: + model = PaLM(num_tokens=50304, dim=4096, depth=64) + model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN) + + pg = default_pg + tensor_parallelize(model, pg) + model = gemini_zero_dpp(model, pg, args.placement) + + # optimizer + + #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5) + optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5) +else: + model = PaLM(num_tokens=256, dim=512, depth=8) + model = AutoregressiveWrapper(model, max_seq_len=2048) + model.cuda() + optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) + +# model is shared after TP +numel = get_model_size(model) +get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN) + +# training +model.train() +tflops_list = [] +for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"): + + if args.distplan == "colossalai": + optimizer.zero_grad() + start = time() + loss = model(next(train_loader)) + fwd_end = time() + fwd_time = fwd_end - start + # loss.backward() + optimizer.backward(loss) + bwd_end = time() + bwd_time = bwd_end - fwd_end + + # print(f"training loss: {loss.item()}") + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + # optim.step() + # optim.zero_grad() + optimizer.step() + optim_time = time() - bwd_end + step_time = time() - start + + step_tflops = get_tflops_func(step_time) + logger.info( + f"[{i + 1}/{NUM_BATCHES}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s", + ranks=[0], + ) + if i >= WARMUP_BATCHES: + tflops_list.append(step_tflops) + + else: + for __ in range(GRADIENT_ACCUMULATE_EVERY): + loss = model(next(train_loader)) + loss.backward() + + print(f"training loss: {loss.item()}") + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + optim.step() + optim.zero_grad() + +tflops_list.sort() +median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES +logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}") + +# TODO +# if i % VALIDATE_EVERY == 0: +# model.eval() +# with torch.no_grad(): +# loss = model(next(val_loader)) +# print(f"validation loss: {loss.item()}") + + # if i % GENERATE_EVERY == 0: + # model.eval() + # inp = random.choice(val_dataset)[:-1] + # prime = decode_tokens(inp) + # print(f"%s \n\n %s", (prime, "*" * 100)) + + # sample = model.generate(inp[None, ...], GENERATE_LENGTH) + # output_str = decode_tokens(sample[0]) + # print(output_str) diff --git a/examples/tutorial/.gitignore b/examples/tutorial/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f873b6a4abbdae50cab83899b5b4a7554a5266e3 --- /dev/null +++ b/examples/tutorial/.gitignore @@ -0,0 +1 @@ +./data/ diff --git a/examples/tutorial/README.md b/examples/tutorial/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f4843331fd54f795938c07677c2be60e55497950 --- /dev/null +++ b/examples/tutorial/README.md @@ -0,0 +1,54 @@ +# Colossal-AI Tutorial Hands-on + +> This path is an abbreviated tutorial prepared for specific activities and may not be maintained in real time. For use of Colossal-AI, please refer to other [examples](https://github.com/hpcaitech/ColossalAI/tree/main/examples) and [documents](https://www.colossalai.org/). + +## Introduction + +Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc. + + +[Colossal-AI](https://github.com/hpcaitech/ColossalAI), a unified deep learning system for the big model era, integrates +many advanced technologies such as multi-dimensional tensor parallelism, sequence parallelism, heterogeneous memory management, +large-scale optimization, adaptive task scheduling, etc. By using Colossal-AI, we could help users to efficiently and +quickly deploy large AI model training and inference, reducing large AI model training budgets and scaling down the labor cost of learning and deployment. + +### 🚀 Quick Links + +[**Colossal-AI**](https://github.com/hpcaitech/ColossalAI) | +[**Paper**](https://arxiv.org/abs/2110.14883) | +[**Documentation**](https://www.colossalai.org/) | +[**Issue**](https://github.com/hpcaitech/ColossalAI/issues/new/choose) | +[**Slack**](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) + +## Table of Content + + - Multi-dimensional Parallelism [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/hybrid_parallel) [[video]](https://www.youtube.com/watch?v=OwUQKdA2Icc) + - Sequence Parallelism [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/sequence_parallel) [[video]](https://www.youtube.com/watch?v=HLLVKb7Cszs) + - Large Batch Training Optimization [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/large_batch_optimizer) [[video]](https://www.youtube.com/watch?v=9Un0ktxJZbI) + - Automatic Parallelism [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/auto_parallel) [[video]](https://www.youtube.com/watch?v=_-2jlyidxqE) + - Fine-tuning and Inference for OPT [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/opt) [[video]](https://www.youtube.com/watch?v=jbEFNVzl67Y) + - Optimized AlphaFold [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/fastfold) [[video]](https://www.youtube.com/watch?v=-zP13LfJP7w) + - Optimized Stable Diffusion [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion) [[video]](https://www.youtube.com/watch?v=8KHeUjjc-XQ) + + +## Discussion + +Discussion about the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) project is always welcomed! We would love to exchange ideas with the community to better help this project grow. +If you think there is a need to discuss anything, you may jump to our [Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w). + +If you encounter any problem while running these tutorials, you may want to raise an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) in this repository. + +## 🛠️ Setup environment +[[video]](https://www.youtube.com/watch?v=dpMYj974ZIc) You should use `conda` to create a virtual environment, we recommend **python 3.8**, e.g. `conda create -n colossal python=3.8`. This installation commands are for CUDA 11.3, if you have a different version of CUDA, please download PyTorch and Colossal-AI accordingly. +You can refer to the [Installation](https://github.com/hpcaitech/ColossalAI#installation) to set up your environment. + +You can run `colossalai check -i` to verify if you have correctly set up your environment 🕹️. +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/colossalai%20check%20-i.png) + +If you encounter messages like `please install with cuda_ext`, do let me know as it could be a problem of the distribution wheel. 😥 + +Then clone the Colossal-AI repository from GitHub. +```bash +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI/examples/tutorial +``` diff --git a/examples/tutorial/auto_parallel/README.md b/examples/tutorial/auto_parallel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6a12e0dd5a4839d617ad2e1aca8468ee3393436e --- /dev/null +++ b/examples/tutorial/auto_parallel/README.md @@ -0,0 +1,68 @@ +# Auto-Parallelism + +## Table of contents + +- [Auto-Parallelism](#auto-parallelism) + - [Table of contents](#table-of-contents) + - [📚 Overview](#-overview) + - [🚀 Quick Start](#-quick-start) + - [Setup](#setup) + - [Auto-Parallel Tutorial](#auto-parallel-tutorial) + - [Auto-Checkpoint Tutorial](#auto-checkpoint-tutorial) + + +## 📚 Overview + +This tutorial folder contains a simple demo to run auto-parallelism with ResNet. Meanwhile, this diretory also contains demo scripts to run automatic activation checkpointing, but both features are still experimental for now and no guarantee that they will work for your version of Colossal-AI. + +## 🚀 Quick Start + +### Setup + +1. Create a conda environment + +```bash +conda create -n auto python=3.8 +conda activate auto +``` + +2. Install `requirements` and `coin-or-cbc` for the solver. + +```bash +pip install -r requirements.txt +conda install -c conda-forge coin-or-cbc +``` + + +### Auto-Parallel Tutorial + +Run the auto parallel resnet example with 4 GPUs with synthetic dataset. + +```bash +colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py +``` + +You should expect to the log like this. This log shows the edge cost on the computation graph as well as the sharding strategy for an operation. For example, `layer1_0_conv1 S01R = S01R X RR` means that the first dimension (batch) of the input and output is sharded while the weight is not sharded (S means sharded, R means replicated), simply equivalent to data parallel training. +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/auto-parallel%20demo.png) + +**Note: This experimental feature has been tested on torch 1.12.1 and transformer 4.22.2. If you are using other versions, you may need to modify the code to make it work.** + +### Auto-Checkpoint Tutorial + +We prepare two bechmarks for you to test the performance of auto checkpoint + +The first test `auto_ckpt_solver_test.py` will show you the ability of solver to search checkpoint strategy that could fit in the given budget (test on GPT2 Medium and ResNet 50). It will output the benchmark summary and data visualization of peak memory vs. budget memory and relative step time vs. peak memory. + +The second test `auto_ckpt_batchsize_test.py` will show you the advantage of fitting larger batchsize training into limited GPU memory with the help of our activation checkpoint solver (test on ResNet152). It will output the benchmark summary. + +The usage of the above two test +```bash +# run auto_ckpt_solver_test.py on gpt2 medium +python auto_ckpt_solver_test.py --model gpt2 + +# run auto_ckpt_solver_test.py on resnet50 +python auto_ckpt_solver_test.py --model resnet50 + +# tun auto_ckpt_batchsize_test.py +python auto_ckpt_batchsize_test.py +``` diff --git a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5a68aae18041946150150e9b63b04f0b9e387e40 --- /dev/null +++ b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py @@ -0,0 +1,52 @@ +from copy import deepcopy +from functools import partial + +import torch +import torchvision.models as tm +from bench_utils import bench, data_gen_resnet + +import colossalai +from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor +from colossalai.fx import metainfo_trace, symbolic_trace +from colossalai.testing import spawn + + +def _benchmark(rank, world_size, port): + """Auto activation checkpoint batchsize benchmark + This benchmark test the through put of Resnet152 with our activation solver given the memory budget of 95% of + maximum GPU memory, and with the batch size of [512, 1024, 2048], you could see that using auto activation + checkpoint with optimality guarantee, we might be able to find better batch size for the model, as larger batch + size means that we are able to use larger portion of GPU FLOPS, while recomputation scheduling with our solver + only result in minor performance drop. So at last we might be able to find better training batch size for our + model (combine with large batch training optimizer such as LAMB). + """ + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = tm.resnet152() + gm = symbolic_trace(model) + raw_graph = deepcopy(gm.graph) + peak_mems, through_puts, batch_sizes = [], [], [512, 1024, 2048] + for batch_size in batch_sizes: + batch_size = int(batch_size) + gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta')) + solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info()[0] * 0.95) + gm.graph = solver.solve() + peak_mem, step_time = bench(gm, + torch.nn.CrossEntropyLoss(), + partial(data_gen_resnet, batch_size=batch_size, shape=(3, 224, 224)), + num_steps=5) + peak_mems.append(peak_mem) + through_puts.append(batch_size / step_time * 1.0e3) + gm.graph = deepcopy(raw_graph) + + # print results + print("===============benchmark summary================") + for batch_size, peak_mem, through_put in zip(batch_sizes, peak_mems, through_puts): + print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s') + + +def auto_activation_checkpoint_batchsize_benchmark(): + spawn(_benchmark, 1) + + +if __name__ == "__main__": + auto_activation_checkpoint_batchsize_benchmark() diff --git a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5c47294a8279a34e912e1cb1fb0aec4f7dcdff --- /dev/null +++ b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py @@ -0,0 +1,87 @@ +import time +from argparse import ArgumentParser +from functools import partial + +import matplotlib.pyplot as plt +import torch +import torchvision.models as tm +from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium + +import colossalai +from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor +from colossalai.fx import metainfo_trace, symbolic_trace +from colossalai.testing import spawn + + +def _benchmark(rank, world_size, port, args): + """ + Auto activation checkpoint solver benchmark, we provide benchmark on two models: gpt2_medium and resnet50. + The benchmark will sample in a range of memory budget for each model and output the benchmark summary and + data visualization of peak memory vs. budget memory and relative step time vs. peak memory. + """ + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + if args.model == 'resnet50': + model = tm.resnet50() + data_gen = partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224)) + gm = symbolic_trace(model) + gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device='meta')) + loss = torch.nn.CrossEntropyLoss() + else: + model = gpt2_medium() + data_gen = partial(data_gen_gpt2, batch_size=8, seq_len=1024, vocab_size=50257) + data, mask = data_gen(device='meta')[0] + gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask}) + gm = metainfo_trace(gm, data, mask) + loss = GPTLMLoss() + + free_memory = 11000 * 1024**2 if args.model == 'resnet50' else 56000 * 1024**2 + start_factor = 4 if args.model == 'resnet50' else 10 + + # trace and benchmark + budgets, peak_hist, step_hist = bench_rotor(gm, + loss, + data_gen, + num_steps=5, + sample_points=15, + free_memory=free_memory, + start_factor=start_factor) + + # print summary + print("==============benchmark summary==============") + for budget, peak, step in zip(budgets, peak_hist, step_hist): + print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS') + + # plot valid results + fig, axs = plt.subplots(1, 2, figsize=(16, 8)) + valid_idx = step_hist.index(next(step for step in step_hist if step != float("inf"))) + + # plot peak memory vs. budget memory + axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:]) + axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--') + axs[0].set_xlabel("Budget Memory (MB)") + axs[0].set_ylabel("Peak Memory (MB)") + axs[0].set_title("Peak Memory vs. Budget Memory") + + # plot relative step time vs. budget memory + axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]]) + axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--') + axs[1].set_xlabel("Peak Memory (MB)") + axs[1].set_ylabel("Relative Step Time") + axs[1].set_title("Step Time vs. Peak Memory") + axs[1].set_ylim(0.8, 1.5) + + # save plot + fig.savefig(f"{args.model}_benchmark.png") + + +def auto_activation_checkpoint_benchmark(args): + world_size = 1 + spawn(_benchmark, world_size, args=args) + + +if __name__ == "__main__": + parser = ArgumentParser("Auto Activation Checkpoint Solver Benchmark") + parser.add_argument("--model", type=str, default='gpt2', choices=['gpt2', 'resnet50']) + args = parser.parse_args() + + auto_activation_checkpoint_benchmark(args) diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a9ad0a312cba1d8771b4e868fcb0d0e92507ee --- /dev/null +++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py @@ -0,0 +1,95 @@ +import torch +from torchvision.models import resnet50 +from tqdm import tqdm + +import colossalai +from colossalai.auto_parallel.tensor_shard.initialize import initialize_model +from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh +from colossalai.logging import get_dist_logger +from colossalai.nn.lr_scheduler import CosineAnnealingLR + + +def synthesize_data(): + img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32) + label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,)) + return img, label + + +def main(): + colossalai.launch_from_torch(config='./config.py') + + logger = get_dist_logger() + + # trace the model with meta data + model = resnet50(num_classes=10).cuda() + + input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')} + device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True) + model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True) + + if gpc.get_global_rank() == 0: + for node_strategy in solution: + print(node_strategy) + # build criterion + criterion = torch.nn.CrossEntropyLoss() + + # optimizer + optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) + + # lr_scheduler + lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS) + + for epoch in range(gpc.config.NUM_EPOCHS): + model.train() + + # if we use synthetic data + # we assume it only has 10 steps per epoch + num_steps = range(10) + progress = tqdm(num_steps) + + for _ in progress: + # generate fake data + img, label = synthesize_data() + + img = img.cuda() + label = label.cuda() + optimizer.zero_grad() + output = model(img) + train_loss = criterion(output, label) + train_loss.backward(train_loss) + torch.cuda.synchronize() + optimizer.step() + lr_scheduler.step() + + # run evaluation + model.eval() + correct = 0 + total = 0 + + # if we use synthetic data + # we assume it only has 10 steps for evaluation + num_steps = range(10) + progress = tqdm(num_steps) + + for _ in progress: + # generate fake data + img, label = synthesize_data() + + img = img.cuda() + label = label.cuda() + + with torch.no_grad(): + output = model(img) + test_loss = criterion(output, label) + pred = torch.argmax(output, dim=-1) + correct += torch.sum(pred == label) + total += img.size(0) + + logger.info( + f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", + ranks=[0]) + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/auto_parallel/bench_utils.py b/examples/tutorial/auto_parallel/bench_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..69859f885ae601efcec049b4771df831ed0da009 --- /dev/null +++ b/examples/tutorial/auto_parallel/bench_utils.py @@ -0,0 +1,170 @@ +import time +from copy import deepcopy +from functools import partial +from typing import Callable, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torchvision.models as tm +from transformers import GPT2Config, GPT2LMHeadModel + +from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor +from colossalai.fx import metainfo_trace + + +def bench(gm: torch.fx.GraphModule, + criterion: torch.nn.Module, + data_gen: Callable, + num_steps: int = 5) -> Tuple[int, int]: + """Benchmarking a given graph module + Args: + gm (torch.fx.GraphModule): The graph module to benchmark. + criterion (torch.nn.Module): Loss function. + data_gen (Callable): Data generator. + num_steps (int, optional): Number of test steps. Defaults to 5. + Returns: + Tuple[int, int]: peak memory in MB and step time in MS. + """ + gm.train() + gm.cuda() + step_time = float('inf') + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + cached = torch.cuda.max_memory_allocated(device="cuda") + try: + for _ in range(num_steps): + args, label = data_gen() + output, loss = None, None + + torch.cuda.synchronize(device="cuda") + start = time.time() + output = gm(*args) + loss = criterion(output, label) + loss.backward() + torch.cuda.synchronize(device="cuda") + step_time = min(step_time, time.time() - start) + + for child in gm.children(): + for param in child.parameters(): + param.grad = None + del args, label, output, loss + except: + del args, label, output, loss + gm.to("cpu") + torch.cuda.empty_cache() + peak_mem = (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2 + return peak_mem, step_time * 1.0e3 + + +def bench_rotor(gm: torch.fx.GraphModule, + criterion: torch.nn.Module, + data_gen: Callable, + num_steps: int = 5, + sample_points: int = 20, + free_memory: int = torch.cuda.mem_get_info()[0], + start_factor: int = 4) -> Tuple[np.array, list, list]: + """Auto Checkpoint Rotor Algorithm benchmarking + Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data. + Args: + gm (torch.fx.GraphModule): The graph module to benchmark. + criterion (torch.nn.Module): Loss function. + data_gen (Callable): Data generator. + num_steps (int, optional): Number of test steps. Defaults to 5. + sample_points (int, optional): Number of sample points. Defaults to 20. + free_memory (int, optional): Max memory budget in Byte. Defaults to torch.cuda.mem_get_info()[0]. + start_factor (int, optional): Start memory budget factor for benchmark, the start memory budget + will be free_memory / start_factor. Defaults to 4. + Returns: + Tuple[np.array, list, list]: return budgets vector (MB), peak memory vector (MB), step time vector (MS). + """ + peak_hist, step_hist = [], [] + raw_graph = deepcopy(gm.graph) + for budget in np.linspace(free_memory // start_factor, free_memory, sample_points): + gm = metainfo_trace(gm, *data_gen()[0]) + solver = CheckpointSolverRotor(gm.graph, free_memory=budget) + try: + gm.graph = solver.solve(verbose=False) + peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps) + except: + peak_memory, step_time = budget / 1024**2, float('inf') + peak_hist.append(peak_memory) + step_hist.append(step_time) + gm.graph = deepcopy(raw_graph) + return np.linspace(free_memory // start_factor, free_memory, sample_points) / 1024**2, peak_hist, step_hist + + +class GPTLMModel(nn.Module): + """ + GPT Model + """ + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +class GPTLMLoss(nn.Module): + """ + GPT Loss + """ + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +def gpt2_medium(checkpoint=False): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_xl(checkpoint=False): + return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint) + + +def gpt2_6b(checkpoint=False): + return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) + + +def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'): + """ + Generate random data for gpt2 benchmarking + """ + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + attention_mask = torch.ones_like(input_ids, device=device) + return (input_ids, attention_mask), attention_mask + + +def data_gen_resnet(batch_size, shape, device='cuda:0'): + """ + Generate random data for resnet benchmarking + """ + data = torch.empty(batch_size, *shape, device=device) + label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000) + return (data,), label diff --git a/examples/tutorial/auto_parallel/config.py b/examples/tutorial/auto_parallel/config.py new file mode 100644 index 0000000000000000000000000000000000000000..52e0abcef6980ba9896dce314e2b34ab97c64cc1 --- /dev/null +++ b/examples/tutorial/auto_parallel/config.py @@ -0,0 +1,2 @@ +BATCH_SIZE = 32 +NUM_EPOCHS = 2 diff --git a/examples/tutorial/auto_parallel/requirements.txt b/examples/tutorial/auto_parallel/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..cc61362ba6f9f61a98a21b6d0e61640b1030c8a7 --- /dev/null +++ b/examples/tutorial/auto_parallel/requirements.txt @@ -0,0 +1,7 @@ +torch==1.12.1 +colossalai +titans +pulp +datasets +matplotlib +transformers==4.22.1 diff --git a/examples/tutorial/auto_parallel/setup.py b/examples/tutorial/auto_parallel/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6cff32ed23ece8a78a066a66433e9ffdd04210 --- /dev/null +++ b/examples/tutorial/auto_parallel/setup.py @@ -0,0 +1,13 @@ +from setuptools import find_packages, setup + +setup( + name='auto_parallel', + version='0.0.1', + description='', + packages=find_packages(), + install_requires=[ + 'torch', + 'numpy', + 'tqdm', + ], +) diff --git a/examples/tutorial/auto_parallel/test_ci.sh b/examples/tutorial/auto_parallel/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..bf6275b673ff7f559708bae2b5bc85dbac23c3ae --- /dev/null +++ b/examples/tutorial/auto_parallel/test_ci.sh @@ -0,0 +1,6 @@ +#!/bin/bash +set -euxo pipefail + +pip install -r requirements.txt +conda install -c conda-forge coin-or-cbc +colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py diff --git a/examples/tutorial/download_cifar10.py b/examples/tutorial/download_cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6b6988ade531f9a6e77955803b7c2dbd88ca9a --- /dev/null +++ b/examples/tutorial/download_cifar10.py @@ -0,0 +1,13 @@ +import os + +from torchvision.datasets import CIFAR10 + + +def main(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + data_root = os.path.join(dir_path, 'data') + dataset = CIFAR10(root=data_root, download=True) + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/fastfold/FastFold b/examples/tutorial/fastfold/FastFold new file mode 160000 index 0000000000000000000000000000000000000000..05681304651b1b29d7d887db169045ea3dd28fce --- /dev/null +++ b/examples/tutorial/fastfold/FastFold @@ -0,0 +1 @@ +Subproject commit 05681304651b1b29d7d887db169045ea3dd28fce diff --git a/examples/tutorial/fastfold/README.md b/examples/tutorial/fastfold/README.md new file mode 100644 index 0000000000000000000000000000000000000000..434d033b9792bfc9939f111725a7d069bb124694 --- /dev/null +++ b/examples/tutorial/fastfold/README.md @@ -0,0 +1,49 @@ +# FastFold Inference + +## Table of contents + +- [FastFold Inference](#fastfold-inference) + - [Table of contents](#table-of-contents) + - [📚 Overview](#-overview) + - [🚀 Quick Start](#-quick-start) + - [🔍 Dive into FastFold](#-dive-into-fastfold) + +## 📚 Overview + +This example lets you to try out the inference of [FastFold](https://github.com/hpcaitech/FastFold). + +## 🚀 Quick Start + +1. Install FastFold + +We highly recommend you to install FastFold with conda. +``` +git clone https://github.com/hpcaitech/FastFold +cd FastFold +conda env create --name=fastfold -f environment.yml +conda activate fastfold +python setup.py install +``` + +2. Download datasets. + +It may take ~900GB space to keep datasets. +``` +./scripts/download_all_data.sh data/ +``` + +3. Run the inference scripts. + +``` +bash inference.sh +``` +You can find predictions under the `outputs` dir. + +## 🔍 Dive into FastFold + +There are another features of [FastFold](https://github.com/hpcaitech/FastFold), such as: ++ more excellent kernel based on triton ++ much faster data processing based on ray ++ training supported + +More detailed information can be seen [here](https://github.com/hpcaitech/FastFold/). diff --git a/examples/tutorial/hybrid_parallel/README.md b/examples/tutorial/hybrid_parallel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1b5e54f928d4713811722bfb99c7a485f1287a52 --- /dev/null +++ b/examples/tutorial/hybrid_parallel/README.md @@ -0,0 +1,40 @@ +# Multi-dimensional Parallelism with Colossal-AI + +## Table of contents + +- [Overview](#-overview) +- [Quick Start](#-quick-start) + +## 📚 Overview + +This example lets you to quickly try out the hybrid parallelism provided by Colossal-AI. +You can change the parameters below to try out different settings in the `config.py`. + +```python +# parallel setting +TENSOR_PARALLEL_SIZE = 2 +TENSOR_PARALLEL_MODE = '1d' + +parallel = dict( + pipeline=2, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), +) +``` + +## 🚀 Quick Start + +1. Install PyTorch + +2. Install the dependencies. + +```bash +pip install -r requirements.txt +``` + +3. Run the training scripts with synthetic data. + +```bash +colossalai run --nproc_per_node 4 train.py --config config.py +``` + +4. Modify the config file to play with different types of tensor parallelism, for example, change tensor parallel size to be 4 and mode to be 2d and run on 8 GPUs. diff --git a/examples/tutorial/hybrid_parallel/config.py b/examples/tutorial/hybrid_parallel/config.py new file mode 100644 index 0000000000000000000000000000000000000000..fe9abf2f1955fc0c9e15bf7d8669c5d05c36ce76 --- /dev/null +++ b/examples/tutorial/hybrid_parallel/config.py @@ -0,0 +1,36 @@ +from colossalai.amp import AMP_TYPE + +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 4 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 2 +WARMUP_EPOCHS = 1 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 128 +DEPTH = 4 +NUM_HEADS = 4 +MLP_RATIO = 2 +NUM_CLASSES = 10 +CHECKPOINT = False +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token + +# parallel setting +TENSOR_PARALLEL_SIZE = 2 +TENSOR_PARALLEL_MODE = '1d' + +parallel = dict( + pipeline=2, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), +) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 + +# pipeline config +NUM_MICRO_BATCHES = parallel['pipeline'] diff --git a/examples/tutorial/hybrid_parallel/requirements.txt b/examples/tutorial/hybrid_parallel/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..99b7ecfe162e2ab29762258553273b02130a19e4 --- /dev/null +++ b/examples/tutorial/hybrid_parallel/requirements.txt @@ -0,0 +1,3 @@ +torch +colossalai +titans diff --git a/examples/tutorial/hybrid_parallel/test_ci.sh b/examples/tutorial/hybrid_parallel/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..e0dbef354e2d85721a3deda62b2c392dad09bb1d --- /dev/null +++ b/examples/tutorial/hybrid_parallel/test_ci.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set -euxo pipefail + +pip install -r requirements.txt +colossalai run --nproc_per_node 4 train.py --config config.py diff --git a/examples/tutorial/hybrid_parallel/train.py b/examples/tutorial/hybrid_parallel/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4953d5350f31ac222330fbcbffc156e3187fd80b --- /dev/null +++ b/examples/tutorial/hybrid_parallel/train.py @@ -0,0 +1,137 @@ +import os + +import torch +from titans.model.vit.vit import _create_vit_model +from tqdm import tqdm + +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.pipeline.pipelinable import PipelinableContext +from colossalai.utils import is_using_pp + + +class DummyDataloader(): + + def __init__(self, length, batch_size): + self.length = length + self.batch_size = batch_size + + def generate(self): + data = torch.rand(self.batch_size, 3, 224, 224) + label = torch.randint(low=0, high=10, size=(self.batch_size,)) + return data, label + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.generate() + else: + raise StopIteration + + def __len__(self): + return self.length + + +def main(): + # launch from torch + parser = colossalai.get_default_parser() + args = parser.parse_args() + colossalai.launch_from_torch(config=args.config) + + # get logger + logger = get_dist_logger() + logger.info("initialized distributed environment", ranks=[0]) + + if hasattr(gpc.config, 'LOG_PATH'): + if gpc.get_global_rank() == 0: + log_path = gpc.config.LOG_PATH + if not os.path.exists(log_path): + os.mkdir(log_path) + logger.log_to_file(log_path) + + use_pipeline = is_using_pp() + + # create model + model_kwargs = dict(img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=10, + init_method='jax', + checkpoint=gpc.config.CHECKPOINT) + + if use_pipeline: + pipelinable = PipelinableContext() + with pipelinable: + model = _create_vit_model(**model_kwargs) + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + else: + model = _create_vit_model(**model_kwargs) + + # count number of parameters + total_numel = 0 + for p in model.parameters(): + total_numel += p.numel() + if not gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_stage = 0 + else: + pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") + + # use synthetic dataset + # we train for 10 steps and eval for 5 steps per epoch + train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE) + test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE) + + # create loss function + criterion = CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS, + warmup_steps=gpc.config.WARMUP_EPOCHS) + + # initialize + engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader) + + logger.info("Engine is built", ranks=[0]) + + for epoch in range(gpc.config.NUM_EPOCHS): + # training + engine.train() + data_iter = iter(train_dataloader) + + if gpc.get_global_rank() == 0: + description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS) + progress = tqdm(range(len(train_dataloader)), desc=description) + else: + progress = range(len(train_dataloader)) + for _ in progress: + engine.zero_grad() + engine.execute_schedule(data_iter, return_output_label=False) + engine.step() + lr_scheduler.step() + gpc.destroy() + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/large_batch_optimizer/README.md b/examples/tutorial/large_batch_optimizer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1a17c2d8740f16db283ea494a0e26993796c453e --- /dev/null +++ b/examples/tutorial/large_batch_optimizer/README.md @@ -0,0 +1,37 @@ +# Large Batch Training Optimization + +## Table of contents + +- [Large Batch Training Optimization](#large-batch-training-optimization) + - [Table of contents](#table-of-contents) + - [📚 Overview](#-overview) + - [🚀 Quick Start](#-quick-start) + +## 📚 Overview + +This example lets you to quickly try out the large batch training optimization provided by Colossal-AI. We use synthetic dataset to go through the process, thus, you don't need to prepare any dataset. You can try out the `Lamb` and `Lars` optimizers from Colossal-AI with the following code. + +```python +from colossalai.nn.optimizer import Lamb, Lars +``` + +## 🚀 Quick Start + +1. Install PyTorch + +2. Install the dependencies. + +```bash +pip install -r requirements.txt +``` + +3. Run the training scripts with synthetic data. + +```bash +# run on 4 GPUs +# run with lars +colossalai run --nproc_per_node 4 train.py --config config.py --optimizer lars + +# run with lamb +colossalai run --nproc_per_node 4 train.py --config config.py --optimizer lamb +``` diff --git a/examples/tutorial/large_batch_optimizer/config.py b/examples/tutorial/large_batch_optimizer/config.py new file mode 100644 index 0000000000000000000000000000000000000000..2efa0ffd0556c6d245dea8e86ffc19cf5ab68cc3 --- /dev/null +++ b/examples/tutorial/large_batch_optimizer/config.py @@ -0,0 +1,16 @@ +from colossalai.amp import AMP_TYPE + +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 512 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 2 +WARMUP_EPOCHS = 1 + +# model config +NUM_CLASSES = 10 + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 diff --git a/examples/tutorial/large_batch_optimizer/requirements.txt b/examples/tutorial/large_batch_optimizer/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c013287751bfc2a27805b4549107ca40e14549a5 --- /dev/null +++ b/examples/tutorial/large_batch_optimizer/requirements.txt @@ -0,0 +1,3 @@ +colossalai +torch +titans diff --git a/examples/tutorial/large_batch_optimizer/test_ci.sh b/examples/tutorial/large_batch_optimizer/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..89f426c542b18f61225ed86eefe15038cba41cfe --- /dev/null +++ b/examples/tutorial/large_batch_optimizer/test_ci.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -euxo pipefail + +pip install -r requirements.txt + +# run test +colossalai run --nproc_per_node 4 --master_port 29500 train.py --config config.py --optimizer lars +colossalai run --nproc_per_node 4 --master_port 29501 train.py --config config.py --optimizer lamb diff --git a/examples/tutorial/large_batch_optimizer/train.py b/examples/tutorial/large_batch_optimizer/train.py new file mode 100644 index 0000000000000000000000000000000000000000..35e54582f49443ce994e77dbd3e1722ae9a7bd01 --- /dev/null +++ b/examples/tutorial/large_batch_optimizer/train.py @@ -0,0 +1,104 @@ +import torch +import torch.nn as nn +from torchvision.models import resnet18 +from tqdm import tqdm + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import Lamb, Lars + + +class DummyDataloader(): + + def __init__(self, length, batch_size): + self.length = length + self.batch_size = batch_size + + def generate(self): + data = torch.rand(self.batch_size, 3, 224, 224) + label = torch.randint(low=0, high=10, size=(self.batch_size,)) + return data, label + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.generate() + else: + raise StopIteration + + def __len__(self): + return self.length + + +def main(): + # initialize distributed setting + parser = colossalai.get_default_parser() + parser.add_argument('--optimizer', + choices=['lars', 'lamb'], + help="Choose your large-batch optimizer", + required=True) + args = parser.parse_args() + + # launch from torch + colossalai.launch_from_torch(config=args.config) + + # get logger + logger = get_dist_logger() + logger.info("initialized distributed environment", ranks=[0]) + + # create synthetic dataloaders + train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE) + test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE) + + # build model + model = resnet18(num_classes=gpc.config.NUM_CLASSES) + + # create loss function + criterion = nn.CrossEntropyLoss() + + # create optimizer + if args.optimizer == "lars": + optim_cls = Lars + elif args.optimizer == "lamb": + optim_cls = Lamb + optimizer = optim_cls(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS, + warmup_steps=gpc.config.WARMUP_EPOCHS) + + # initialize + engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader) + + logger.info("Engine is built", ranks=[0]) + + for epoch in range(gpc.config.NUM_EPOCHS): + # training + engine.train() + data_iter = iter(train_dataloader) + + if gpc.get_global_rank() == 0: + description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS) + progress = tqdm(range(len(train_dataloader)), desc=description) + else: + progress = range(len(train_dataloader)) + for _ in progress: + engine.zero_grad() + engine.execute_schedule(data_iter, return_output_label=False) + engine.step() + lr_scheduler.step() + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/new_api/README.md b/examples/tutorial/new_api/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cec88f41caf1d0ab6c2355a955ec0660a1fa344a --- /dev/null +++ b/examples/tutorial/new_api/README.md @@ -0,0 +1,5 @@ +# New API Features + +**The New API is not officially released yet.** + +This folder contains some of the demonstrations of the new API. The new API is still under intensive development and will be released soon. diff --git a/examples/tutorial/new_api/test_ci.sh b/examples/tutorial/new_api/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..8b4475e9f1473acd9456e050cb338395f2fbb60e --- /dev/null +++ b/examples/tutorial/new_api/test_ci.sh @@ -0,0 +1,2 @@ +#!/usr/bin/env +echo "The CI integration will be completed when the API is stable" diff --git a/examples/tutorial/new_api/torch_ddp/.gitignore b/examples/tutorial/new_api/torch_ddp/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a79cf5236c088af99b3891cb3dc536aaaded808c --- /dev/null +++ b/examples/tutorial/new_api/torch_ddp/.gitignore @@ -0,0 +1,4 @@ +data +checkpoint +ckpt-fp16 +ckpt-fp32 diff --git a/examples/tutorial/new_api/torch_ddp/README.md b/examples/tutorial/new_api/torch_ddp/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e120bacb0c84269565c6185efb3666a87680d35e --- /dev/null +++ b/examples/tutorial/new_api/torch_ddp/README.md @@ -0,0 +1,44 @@ +# Distributed Data Parallel + +## 🚀 Quick Start + +This example provides a training script and an evaluation script. The training script provides an example of training ResNet on CIFAR10 dataset from scratch. + +- Training Arguments + - `-r`, `--resume`: resume from checkpoint file path + - `-c`, `--checkpoint`: the folder to save checkpoints + - `-i`, `--interval`: epoch interval to save checkpoints + - `-f`, `--fp16`: use fp16 + +- Eval Arguments + - `-e`, `--epoch`: select the epoch to evaluate + - `-c`, `--checkpoint`: the folder where checkpoints are found + + +### Train + +```bash +# train with torch DDP with fp32 +colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32 + +# train with torch DDP with mixed precision training +colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp16 --fp16 +``` + +### Eval + +```bash +# evaluate fp32 training +python eval.py -c ./ckpt-fp32 -e 80 + +# evaluate fp16 mixed precision training +python eval.py -c ./ckpt-fp16 -e 80 +``` + +Expected accuracy performance will be: + +| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | +| --------- | ------------------------ | --------------------- | --------------------- | +| ResNet-18 | 85.85% | 85.03% | 85.12% | + +**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`** diff --git a/examples/tutorial/new_api/torch_ddp/eval.py b/examples/tutorial/new_api/torch_ddp/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..657708ec3ff26a699a7abe12a3c3330a96194a23 --- /dev/null +++ b/examples/tutorial/new_api/torch_ddp/eval.py @@ -0,0 +1,48 @@ +import argparse + +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms + +# ============================== +# Parse Arguments +# ============================== +parser = argparse.ArgumentParser() +parser.add_argument('-e', '--epoch', type=int, default=80, help="resume from the epoch's checkpoint") +parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") +args = parser.parse_args() + +# ============================== +# Prepare Test Dataset +# ============================== +# CIFAR-10 dataset +test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor()) + +# Data loader +test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False) + +# ============================== +# Load Model +# ============================== +model = torchvision.models.resnet18(num_classes=10).cuda() +state_dict = torch.load(f'{args.checkpoint}/model_{args.epoch}.pth') +model.load_state_dict(state_dict) + +# ============================== +# Run Evaluation +# ============================== +model.eval() + +with torch.no_grad(): + correct = 0 + total = 0 + for images, labels in test_loader: + images = images.cuda() + labels = labels.cuda() + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + print('Accuracy of the model on the test images: {} %'.format(100 * correct / total)) diff --git a/examples/tutorial/new_api/torch_ddp/train.py b/examples/tutorial/new_api/torch_ddp/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4741c3151cbbbbb210cd2d5a7b4417625925ae28 --- /dev/null +++ b/examples/tutorial/new_api/torch_ddp/train.py @@ -0,0 +1,128 @@ +import argparse +from pathlib import Path + +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +from torch.optim.lr_scheduler import MultiStepLR + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.cluster import DistCoordinator + +# ============================== +# Parse Arguments +# ============================== +parser = argparse.ArgumentParser() +parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint") +parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") +parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint") +parser.add_argument('-f', '--fp16', action='store_true', help="use fp16") +args = parser.parse_args() + +# ============================== +# Prepare Checkpoint Directory +# ============================== +Path(args.checkpoint).mkdir(parents=True, exist_ok=True) + +# ============================== +# Prepare Hyperparameters +# ============================== +NUM_EPOCHS = 80 +LEARNING_RATE = 1e-3 +START_EPOCH = args.resume if args.resume >= 0 else 0 + +# ============================== +# Launch Distributed Environment +# ============================== +colossalai.launch_from_torch(config={}) +coordinator = DistCoordinator() + +# update the learning rate with linear scaling +# old_gpu_num / old_lr = new_gpu_num / new_lr +LEARNING_RATE *= coordinator.world_size + +# ============================== +# Prepare Booster +# ============================== +plugin = TorchDDPPlugin() +if args.fp16: + booster = Booster(mixed_precision='fp16', plugin=plugin) +else: + booster = Booster(plugin=plugin) + +# ============================== +# Prepare Train Dataset +# ============================== +transform = transforms.Compose( + [transforms.Pad(4), + transforms.RandomHorizontalFlip(), + transforms.RandomCrop(32), + transforms.ToTensor()]) + +# CIFAR-10 dataset +with coordinator.priority_execution(): + train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True, transform=transform, download=True) + +# ==================================== +# Prepare model, optimizer, criterion +# ==================================== +# resent50 +model = torchvision.models.resnet18(num_classes=10).cuda() + +# Loss and optimizer +criterion = nn.CrossEntropyLoss() +optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) + +# lr scheduler +lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3) + +# prepare dataloader with torch ddp plugin +train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=100, shuffle=True) + +# ============================== +# Resume from checkpoint +# ============================== +if args.resume >= 0: + booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth') + booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth') + booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth') + +# ============================== +# Boost with ColossalAI +# ============================== +model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, + train_dataloader, lr_scheduler) + +# ============================== +# Train model +# ============================== +total_step = len(train_dataloader) + +for epoch in range(START_EPOCH, NUM_EPOCHS): + for i, (images, labels) in enumerate(train_dataloader): + images = images.cuda() + labels = labels.cuda() + + # Forward pass + outputs = model(images) + loss = criterion(outputs, labels) + + # Backward and optimize + optimizer.zero_grad() + booster.backward(loss, optimizer) + optimizer.step() + + if (i + 1) % 100 == 0: + print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}".format(epoch + 1, NUM_EPOCHS, i + 1, total_step, + loss.item())) + + lr_scheduler.step() + + # save checkpoint every 5 epoch + if (epoch + 1) % args.interval == 0: + booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth') + booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth') + booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth') diff --git a/examples/tutorial/opt/inference/README.md b/examples/tutorial/opt/inference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..20ad4a23fdeb9a719ae9f8b9875e6404371f98f9 --- /dev/null +++ b/examples/tutorial/opt/inference/README.md @@ -0,0 +1,88 @@ +# Overview + +This is an example showing how to run OPT generation. The OPT model is implemented using ColossalAI. + +It supports tensor parallelism, batching and caching. + +## 🚀Quick Start +1. Run inference with OPT 125M +```bash +docker hpcaitech/tutorial:opt-inference +docker run -it --rm --gpus all --ipc host -p 7070:7070 hpcaitech/tutorial:opt-inference +``` +2. Start the http server inside the docker container with tensor parallel size 2 +```bash +python opt_fastapi.py opt-125m --tp 2 --checkpoint /data/opt-125m +``` + +# How to run + +Run OPT-125M: +```shell +python opt_fastapi.py opt-125m +``` + +It will launch a HTTP server on `0.0.0.0:7070` by default and you can customize host and port. You can open `localhost:7070/docs` in your browser to see the openapi docs. + +## Configure + +### Configure model +```shell +python opt_fastapi.py +``` +Available models: opt-125m, opt-6.7b, opt-30b, opt-175b. + +### Configure tensor parallelism +```shell +python opt_fastapi.py --tp +``` +The `` can be an integer in `[1, #GPUs]`. Default `1`. + +### Configure checkpoint +```shell +python opt_fastapi.py --checkpoint +``` +The `` can be a file path or a directory path. If it's a directory path, all files under the directory will be loaded. + +### Configure queue +```shell +python opt_fastapi.py --queue_size +``` +The `` can be an integer in `[0, MAXINT]`. If it's `0`, the request queue size is infinite. If it's a positive integer, when the request queue is full, incoming requests will be dropped (the HTTP status code of response will be 406). + +### Configure batching +```shell +python opt_fastapi.py --max_batch_size +``` +The `` can be an integer in `[1, MAXINT]`. The engine will make batch whose size is less or equal to this value. + +Note that the batch size is not always equal to ``, as some consecutive requests may not be batched. + +### Configure caching +```shell +python opt_fastapi.py --cache_size --cache_list_size +``` +This will cache `` unique requests. And for each unique request, it cache `` different results. A random result will be returned if the cache is hit. + +The `` can be an integer in `[0, MAXINT]`. If it's `0`, cache won't be applied. The `` can be an integer in `[1, MAXINT]`. + +### Other configurations +```shell +python opt_fastapi.py -h +``` + +# How to benchmark +```shell +cd benchmark +locust +``` + +Then open the web interface link which is on your console. + +# Pre-process pre-trained weights + +## OPT-66B +See [script/processing_ckpt_66b.py](./script/processing_ckpt_66b.py). + +## OPT-175B +See [script/process-opt-175b](./script/process-opt-175b/). diff --git a/examples/tutorial/opt/inference/batch.py b/examples/tutorial/opt/inference/batch.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0876ca833890fd2bfac3b4d9f342a05c67928f --- /dev/null +++ b/examples/tutorial/opt/inference/batch.py @@ -0,0 +1,59 @@ +import torch +from typing import List, Deque, Tuple, Hashable, Any +from energonai import BatchManager, SubmitEntry, TaskEntry + + +class BatchManagerForGeneration(BatchManager): + def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None: + super().__init__() + self.max_batch_size = max_batch_size + self.pad_token_id = pad_token_id + + def _left_padding(self, batch_inputs): + max_len = max(len(inputs['input_ids']) for inputs in batch_inputs) + outputs = {'input_ids': [], 'attention_mask': []} + for inputs in batch_inputs: + input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] + padding_len = max_len - len(input_ids) + input_ids = [self.pad_token_id] * padding_len + input_ids + attention_mask = [0] * padding_len + attention_mask + outputs['input_ids'].append(input_ids) + outputs['attention_mask'].append(attention_mask) + for k in outputs: + outputs[k] = torch.tensor(outputs[k]) + return outputs, max_len + + @staticmethod + def _make_batch_key(entry: SubmitEntry) -> tuple: + data = entry.data + return (data['top_k'], data['top_p'], data['temperature']) + + def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]: + entry = q.popleft() + uids = [entry.uid] + batch = [entry.data] + while len(batch) < self.max_batch_size: + if len(q) == 0: + break + if self._make_batch_key(entry) != self._make_batch_key(q[0]): + break + if q[0].data['max_tokens'] > entry.data['max_tokens']: + break + e = q.popleft() + batch.append(e.data) + uids.append(e.uid) + inputs, max_len = self._left_padding(batch) + trunc_lens = [] + for data in batch: + trunc_lens.append(max_len + data['max_tokens']) + inputs['top_k'] = entry.data['top_k'] + inputs['top_p'] = entry.data['top_p'] + inputs['temperature'] = entry.data['temperature'] + inputs['max_tokens'] = max_len + entry.data['max_tokens'] + return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens} + + def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]: + retval = [] + for uid, output, trunc_len in zip(task_entry.uids, task_entry.batch, trunc_lens): + retval.append((uid, output[:trunc_len])) + return retval diff --git a/examples/tutorial/opt/inference/benchmark/locustfile.py b/examples/tutorial/opt/inference/benchmark/locustfile.py new file mode 100644 index 0000000000000000000000000000000000000000..4d829e5d83bf73c45b20d05bfa549b5ee3b869ba --- /dev/null +++ b/examples/tutorial/opt/inference/benchmark/locustfile.py @@ -0,0 +1,15 @@ +from locust import HttpUser, task +from json import JSONDecodeError + + +class GenerationUser(HttpUser): + @task + def generate(self): + prompt = 'Question: What is the longest river on the earth? Answer:' + for i in range(4, 9): + data = {'max_tokens': 2**i, 'prompt': prompt} + with self.client.post('/generation', json=data, catch_response=True) as response: + if response.status_code in (200, 406): + response.success() + else: + response.failure('Response wrong') diff --git a/examples/tutorial/opt/inference/cache.py b/examples/tutorial/opt/inference/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..30febc44fbb3d7afb703c641e582562f90c6b8d0 --- /dev/null +++ b/examples/tutorial/opt/inference/cache.py @@ -0,0 +1,64 @@ +from collections import OrderedDict +from threading import Lock +from contextlib import contextmanager +from typing import List, Any, Hashable, Dict + + +class MissCacheError(Exception): + pass + + +class ListCache: + def __init__(self, cache_size: int, list_size: int, fixed_keys: List[Hashable] = []) -> None: + """Cache a list of values. The fixed keys won't be removed. For other keys, LRU is applied. + When the value list is not full, a cache miss occurs. Otherwise, a cache hit occurs. Redundant values will be removed. + + Args: + cache_size (int): Max size for LRU cache. + list_size (int): Value list size. + fixed_keys (List[Hashable], optional): The keys which won't be removed. Defaults to []. + """ + self.cache_size = cache_size + self.list_size = list_size + self.cache: OrderedDict[Hashable, List[Any]] = OrderedDict() + self.fixed_cache: Dict[Hashable, List[Any]] = {} + for key in fixed_keys: + self.fixed_cache[key] = [] + self._lock = Lock() + + def get(self, key: Hashable) -> List[Any]: + with self.lock(): + if key in self.fixed_cache: + l = self.fixed_cache[key] + if len(l) >= self.list_size: + return l + elif key in self.cache: + self.cache.move_to_end(key) + l = self.cache[key] + if len(l) >= self.list_size: + return l + raise MissCacheError() + + def add(self, key: Hashable, value: Any) -> None: + with self.lock(): + if key in self.fixed_cache: + l = self.fixed_cache[key] + if len(l) < self.list_size and value not in l: + l.append(value) + elif key in self.cache: + self.cache.move_to_end(key) + l = self.cache[key] + if len(l) < self.list_size and value not in l: + l.append(value) + else: + if len(self.cache) >= self.cache_size: + self.cache.popitem(last=False) + self.cache[key] = [value] + + @contextmanager + def lock(self): + try: + self._lock.acquire() + yield + finally: + self._lock.release() diff --git a/examples/tutorial/opt/inference/opt_fastapi.py b/examples/tutorial/opt/inference/opt_fastapi.py new file mode 100644 index 0000000000000000000000000000000000000000..cbfc2a22e7c0c98070d940f3a42cfd66ce839028 --- /dev/null +++ b/examples/tutorial/opt/inference/opt_fastapi.py @@ -0,0 +1,123 @@ +import argparse +import logging +import random +from typing import Optional + +import uvicorn +from energonai import QueueFullError, launch_engine +from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B +from fastapi import FastAPI, HTTPException, Request +from pydantic import BaseModel, Field +from transformers import GPT2Tokenizer + +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError + + +class GenerationTaskReq(BaseModel): + max_tokens: int = Field(gt=0, le=256, example=64) + prompt: str = Field( + min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + top_k: Optional[int] = Field(default=None, gt=0, example=50) + top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) + temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) + + +app = FastAPI() + + +@app.post('/generation') +async def generate(data: GenerationTaskReq, request: Request): + logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}') + key = (data.prompt, data.max_tokens) + try: + if cache is None: + raise MissCacheError() + outputs = cache.get(key) + output = random.choice(outputs) + logger.info('Cache hit') + except MissCacheError: + inputs = tokenizer(data.prompt, truncation=True, max_length=512) + inputs['max_tokens'] = data.max_tokens + inputs['top_k'] = data.top_k + inputs['top_p'] = data.top_p + inputs['temperature'] = data.temperature + try: + uid = id(data) + engine.submit(uid, inputs) + output = await engine.wait(uid) + output = tokenizer.decode(output, skip_special_tokens=True) + if cache is not None: + cache.add(key, output) + except QueueFullError as e: + raise HTTPException(status_code=406, detail=e.args[0]) + + return {'text': output} + + +@app.on_event("shutdown") +async def shutdown(*_): + engine.shutdown() + server.should_exit = True + server.force_exit = True + await server.shutdown() + + +def get_model_fn(model_name: str): + model_map = { + 'opt-125m': opt_125M, + 'opt-6.7b': opt_6B, + 'opt-30b': opt_30B, + 'opt-175b': opt_175B + } + return model_map[model_name] + + +def print_args(args: argparse.Namespace): + print('\n==> Args:') + for k, v in args.__dict__.items(): + print(f'{k} = {v}') + + +FIXED_CACHE_KEYS = [ + ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), + ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), + ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) +] + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) + parser.add_argument('--tp', type=int, default=1) + parser.add_argument('--master_host', default='localhost') + parser.add_argument('--master_port', type=int, default=19990) + parser.add_argument('--rpc_port', type=int, default=19980) + parser.add_argument('--max_batch_size', type=int, default=8) + parser.add_argument('--pipe_size', type=int, default=1) + parser.add_argument('--queue_size', type=int, default=0) + parser.add_argument('--http_host', default='0.0.0.0') + parser.add_argument('--http_port', type=int, default=7070) + parser.add_argument('--checkpoint', default=None) + parser.add_argument('--cache_size', type=int, default=0) + parser.add_argument('--cache_list_size', type=int, default=1) + args = parser.parse_args() + print_args(args) + model_kwargs = {} + if args.checkpoint is not None: + model_kwargs['checkpoint'] = args.checkpoint + + logger = logging.getLogger(__name__) + tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') + if args.cache_size > 0: + cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) + else: + cache = None + engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), + batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, + pad_token_id=tokenizer.pad_token_id), + pipe_size=args.pipe_size, + queue_size=args.queue_size, + **model_kwargs) + config = uvicorn.Config(app, host=args.http_host, port=args.http_port) + server = uvicorn.Server(config=config) + server.run() diff --git a/examples/tutorial/opt/inference/opt_server.py b/examples/tutorial/opt/inference/opt_server.py new file mode 100644 index 0000000000000000000000000000000000000000..8dab82622c59c313aee836a4509a27fd48ccad1c --- /dev/null +++ b/examples/tutorial/opt/inference/opt_server.py @@ -0,0 +1,122 @@ +import logging +import argparse +import random +from torch import Tensor +from pydantic import BaseModel, Field +from typing import Optional +from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B +from transformers import GPT2Tokenizer +from energonai import launch_engine, QueueFullError +from sanic import Sanic +from sanic.request import Request +from sanic.response import json +from sanic_ext import validate, openapi +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError + + +class GenerationTaskReq(BaseModel): + max_tokens: int = Field(gt=0, le=256, example=64) + prompt: str = Field( + min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + top_k: Optional[int] = Field(default=None, gt=0, example=50) + top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) + temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) + + +app = Sanic('opt') + + +@app.post('/generation') +@openapi.body(GenerationTaskReq) +@validate(json=GenerationTaskReq) +async def generate(request: Request, body: GenerationTaskReq): + logger.info(f'{request.ip}:{request.port} - "{request.method} {request.path}" - {body}') + key = (body.prompt, body.max_tokens) + try: + if cache is None: + raise MissCacheError() + outputs = cache.get(key) + output = random.choice(outputs) + logger.info('Cache hit') + except MissCacheError: + inputs = tokenizer(body.prompt, truncation=True, max_length=512) + inputs['max_tokens'] = body.max_tokens + inputs['top_k'] = body.top_k + inputs['top_p'] = body.top_p + inputs['temperature'] = body.temperature + try: + uid = id(body) + engine.submit(uid, inputs) + output = await engine.wait(uid) + assert isinstance(output, Tensor) + output = tokenizer.decode(output, skip_special_tokens=True) + if cache is not None: + cache.add(key, output) + except QueueFullError as e: + return json({'detail': e.args[0]}, status=406) + + return json({'text': output}) + + +@app.after_server_stop +def shutdown(*_): + engine.shutdown() + + +def get_model_fn(model_name: str): + model_map = { + 'opt-125m': opt_125M, + 'opt-6.7b': opt_6B, + 'opt-30b': opt_30B, + 'opt-175b': opt_175B + } + return model_map[model_name] + + +def print_args(args: argparse.Namespace): + print('\n==> Args:') + for k, v in args.__dict__.items(): + print(f'{k} = {v}') + + +FIXED_CACHE_KEYS = [ + ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), + ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), + ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) +] + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) + parser.add_argument('--tp', type=int, default=1) + parser.add_argument('--master_host', default='localhost') + parser.add_argument('--master_port', type=int, default=19990) + parser.add_argument('--rpc_port', type=int, default=19980) + parser.add_argument('--max_batch_size', type=int, default=8) + parser.add_argument('--pipe_size', type=int, default=1) + parser.add_argument('--queue_size', type=int, default=0) + parser.add_argument('--http_host', default='0.0.0.0') + parser.add_argument('--http_port', type=int, default=7070) + parser.add_argument('--checkpoint', default=None) + parser.add_argument('--cache_size', type=int, default=0) + parser.add_argument('--cache_list_size', type=int, default=1) + args = parser.parse_args() + print_args(args) + model_kwargs = {} + if args.checkpoint is not None: + model_kwargs['checkpoint'] = args.checkpoint + + logger = logging.getLogger(__name__) + tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') + if args.cache_size > 0: + cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) + else: + cache = None + engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), + batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, + pad_token_id=tokenizer.pad_token_id), + pipe_size=args.pipe_size, + queue_size=args.queue_size, + **model_kwargs) + app.run(args.http_host, args.http_port) diff --git a/examples/tutorial/opt/inference/requirements.txt b/examples/tutorial/opt/inference/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..966dff4746f24591a88987b02845ce1d1edb7243 --- /dev/null +++ b/examples/tutorial/opt/inference/requirements.txt @@ -0,0 +1,10 @@ +fastapi==0.85.1 +locust==2.11.0 +pydantic==1.10.2 +sanic==22.9.0 +sanic_ext==22.9.0 +torch>=1.10.0 +transformers==4.23.1 +uvicorn==0.19.0 +colossalai +git+https://github.com/hpcaitech/EnergonAI@main diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/README.md b/examples/tutorial/opt/inference/script/process-opt-175b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bc3cba72df33c3242ed35625e3467aaefb4849e1 --- /dev/null +++ b/examples/tutorial/opt/inference/script/process-opt-175b/README.md @@ -0,0 +1,46 @@ +# Process OPT-175B weights + +You should download the pre-trained weights following the [doc](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT) before reading this. + +First, install `metaseq` and `git clone https://github.com/facebookresearch/metaseq.git`. + +Then, `cd metaseq`. + +To consolidate checkpoints to eliminate FSDP: + +```shell +bash metaseq/scripts/reshard_mp_launch_no_slurm.sh /checkpoint_last / 8 1 +``` + +You will get 8 files in ``, and you should have the following checksums: +``` +7e71cb65c4be784aa0b2889ac6039ee8 reshard-model_part-0-shard0.pt +c8123da04f2c25a9026ea3224d5d5022 reshard-model_part-1-shard0.pt +45e5d10896382e5bc4a7064fcafd2b1e reshard-model_part-2-shard0.pt +abb7296c4d2fc17420b84ca74fc3ce64 reshard-model_part-3-shard0.pt +05dcc7ac6046f4d3f90b3d1068e6da15 reshard-model_part-4-shard0.pt +d24dd334019060ce1ee7e625fcf6b4bd reshard-model_part-5-shard0.pt +fb1615ce0bbe89cc717f3e5079ee2655 reshard-model_part-6-shard0.pt +2f3124432d2dbc6aebfca06be4b791c2 reshard-model_part-7-shard0.pt +``` + +Copy `flat-meta.json` to ``. + +Then cd to this dir, and we unflatten parameters. + +```shell +bash unflat.sh / / +``` + +Finally, you will get 8 files in `` with following checksums: +``` +6169c59d014be95553c89ec01b8abb62 reshard-model_part-0.pt +58868105da3d74a528a548fdb3a8cff6 reshard-model_part-1.pt +69b255dc5a49d0eba9e4b60432cda90b reshard-model_part-2.pt +002c052461ff9ffb0cdac3d5906f41f2 reshard-model_part-3.pt +6d57f72909320d511ffd5f1c668b2beb reshard-model_part-4.pt +93c8c4041cdc0c7907cc7afcf15cec2a reshard-model_part-5.pt +5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt +f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt +``` + diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py b/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..a17ddd4fa1735fcd1a114f1fd3c87870c257d6bd --- /dev/null +++ b/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py @@ -0,0 +1,55 @@ +import argparse +import json +import os +import re +from collections import defaultdict + +import numpy as np +import torch + + +def load_json(path: str): + with open(path) as f: + return json.load(f) + + +def parse_shape_info(flat_dir: str): + data = load_json(os.path.join(flat_dir, 'shape.json')) + flat_info = defaultdict(lambda: defaultdict(list)) + for k, shape in data.items(): + matched = re.match(r'decoder.layers.\d+', k) + if matched is None: + flat_key = 'flat_param_0' + else: + flat_key = f'{matched[0]}.flat_param_0' + flat_info[flat_key]['names'].append(k) + flat_info[flat_key]['shapes'].append(shape) + flat_info[flat_key]['numels'].append(int(np.prod(shape))) + return flat_info + + +def convert(flat_dir: str, output_dir: str, part: int): + flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt') + output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt') + flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json')) + flat_sd = torch.load(flat_path) + print(f'Loaded flat state dict from {flat_path}') + output_sd = {} + for flat_key, param_meta in flat_meta.items(): + flat_param = flat_sd['model'][flat_key] + assert sum(param_meta['numels']) == flat_param.numel( + ), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}' + for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])): + output_sd[name] = param.view(shape) + + torch.save(output_sd, output_path) + print(f'Saved unflat state dict to {output_path}') + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('flat_dir') + parser.add_argument('output_dir') + parser.add_argument('part', type=int) + args = parser.parse_args() + convert(args.flat_dir, args.output_dir, args.part) diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json b/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json new file mode 100644 index 0000000000000000000000000000000000000000..59d285565cfdb8d68ce0e5dc91aab2046c136e9c --- /dev/null +++ b/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json @@ -0,0 +1 @@ +{"flat_param_0": {"names": ["decoder.embed_tokens.weight", "decoder.embed_positions.weight", "decoder.layer_norm.weight", "decoder.layer_norm.bias"], "shapes": [[6284, 12288], [2050, 12288], [12288], [12288]], "numels": [77217792, 25190400, 12288, 12288]}, "decoder.layers.0.flat_param_0": {"names": ["decoder.layers.0.self_attn.qkv_proj.weight", "decoder.layers.0.self_attn.qkv_proj.bias", "decoder.layers.0.self_attn.out_proj.weight", "decoder.layers.0.self_attn.out_proj.bias", "decoder.layers.0.self_attn_layer_norm.weight", "decoder.layers.0.self_attn_layer_norm.bias", "decoder.layers.0.fc1.weight", "decoder.layers.0.fc1.bias", "decoder.layers.0.fc2.weight", "decoder.layers.0.fc2.bias", "decoder.layers.0.final_layer_norm.weight", "decoder.layers.0.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.1.flat_param_0": {"names": ["decoder.layers.1.self_attn.qkv_proj.weight", "decoder.layers.1.self_attn.qkv_proj.bias", "decoder.layers.1.self_attn.out_proj.weight", "decoder.layers.1.self_attn.out_proj.bias", "decoder.layers.1.self_attn_layer_norm.weight", "decoder.layers.1.self_attn_layer_norm.bias", "decoder.layers.1.fc1.weight", "decoder.layers.1.fc1.bias", "decoder.layers.1.fc2.weight", "decoder.layers.1.fc2.bias", "decoder.layers.1.final_layer_norm.weight", "decoder.layers.1.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.2.flat_param_0": {"names": ["decoder.layers.2.self_attn.qkv_proj.weight", "decoder.layers.2.self_attn.qkv_proj.bias", "decoder.layers.2.self_attn.out_proj.weight", "decoder.layers.2.self_attn.out_proj.bias", "decoder.layers.2.self_attn_layer_norm.weight", "decoder.layers.2.self_attn_layer_norm.bias", "decoder.layers.2.fc1.weight", "decoder.layers.2.fc1.bias", "decoder.layers.2.fc2.weight", "decoder.layers.2.fc2.bias", "decoder.layers.2.final_layer_norm.weight", "decoder.layers.2.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.3.flat_param_0": {"names": ["decoder.layers.3.self_attn.qkv_proj.weight", "decoder.layers.3.self_attn.qkv_proj.bias", "decoder.layers.3.self_attn.out_proj.weight", "decoder.layers.3.self_attn.out_proj.bias", "decoder.layers.3.self_attn_layer_norm.weight", "decoder.layers.3.self_attn_layer_norm.bias", "decoder.layers.3.fc1.weight", "decoder.layers.3.fc1.bias", "decoder.layers.3.fc2.weight", "decoder.layers.3.fc2.bias", "decoder.layers.3.final_layer_norm.weight", "decoder.layers.3.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.4.flat_param_0": {"names": ["decoder.layers.4.self_attn.qkv_proj.weight", "decoder.layers.4.self_attn.qkv_proj.bias", "decoder.layers.4.self_attn.out_proj.weight", "decoder.layers.4.self_attn.out_proj.bias", "decoder.layers.4.self_attn_layer_norm.weight", "decoder.layers.4.self_attn_layer_norm.bias", "decoder.layers.4.fc1.weight", "decoder.layers.4.fc1.bias", "decoder.layers.4.fc2.weight", "decoder.layers.4.fc2.bias", "decoder.layers.4.final_layer_norm.weight", "decoder.layers.4.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.5.flat_param_0": {"names": ["decoder.layers.5.self_attn.qkv_proj.weight", "decoder.layers.5.self_attn.qkv_proj.bias", "decoder.layers.5.self_attn.out_proj.weight", "decoder.layers.5.self_attn.out_proj.bias", "decoder.layers.5.self_attn_layer_norm.weight", "decoder.layers.5.self_attn_layer_norm.bias", "decoder.layers.5.fc1.weight", "decoder.layers.5.fc1.bias", "decoder.layers.5.fc2.weight", "decoder.layers.5.fc2.bias", "decoder.layers.5.final_layer_norm.weight", "decoder.layers.5.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.6.flat_param_0": {"names": ["decoder.layers.6.self_attn.qkv_proj.weight", "decoder.layers.6.self_attn.qkv_proj.bias", "decoder.layers.6.self_attn.out_proj.weight", "decoder.layers.6.self_attn.out_proj.bias", "decoder.layers.6.self_attn_layer_norm.weight", "decoder.layers.6.self_attn_layer_norm.bias", "decoder.layers.6.fc1.weight", "decoder.layers.6.fc1.bias", "decoder.layers.6.fc2.weight", "decoder.layers.6.fc2.bias", "decoder.layers.6.final_layer_norm.weight", "decoder.layers.6.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.7.flat_param_0": {"names": ["decoder.layers.7.self_attn.qkv_proj.weight", "decoder.layers.7.self_attn.qkv_proj.bias", "decoder.layers.7.self_attn.out_proj.weight", "decoder.layers.7.self_attn.out_proj.bias", "decoder.layers.7.self_attn_layer_norm.weight", "decoder.layers.7.self_attn_layer_norm.bias", "decoder.layers.7.fc1.weight", "decoder.layers.7.fc1.bias", "decoder.layers.7.fc2.weight", "decoder.layers.7.fc2.bias", "decoder.layers.7.final_layer_norm.weight", "decoder.layers.7.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.8.flat_param_0": {"names": ["decoder.layers.8.self_attn.qkv_proj.weight", "decoder.layers.8.self_attn.qkv_proj.bias", "decoder.layers.8.self_attn.out_proj.weight", "decoder.layers.8.self_attn.out_proj.bias", "decoder.layers.8.self_attn_layer_norm.weight", "decoder.layers.8.self_attn_layer_norm.bias", "decoder.layers.8.fc1.weight", "decoder.layers.8.fc1.bias", "decoder.layers.8.fc2.weight", "decoder.layers.8.fc2.bias", "decoder.layers.8.final_layer_norm.weight", "decoder.layers.8.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.9.flat_param_0": {"names": ["decoder.layers.9.self_attn.qkv_proj.weight", "decoder.layers.9.self_attn.qkv_proj.bias", "decoder.layers.9.self_attn.out_proj.weight", "decoder.layers.9.self_attn.out_proj.bias", "decoder.layers.9.self_attn_layer_norm.weight", "decoder.layers.9.self_attn_layer_norm.bias", "decoder.layers.9.fc1.weight", "decoder.layers.9.fc1.bias", "decoder.layers.9.fc2.weight", "decoder.layers.9.fc2.bias", "decoder.layers.9.final_layer_norm.weight", "decoder.layers.9.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.10.flat_param_0": {"names": ["decoder.layers.10.self_attn.qkv_proj.weight", "decoder.layers.10.self_attn.qkv_proj.bias", "decoder.layers.10.self_attn.out_proj.weight", "decoder.layers.10.self_attn.out_proj.bias", "decoder.layers.10.self_attn_layer_norm.weight", "decoder.layers.10.self_attn_layer_norm.bias", "decoder.layers.10.fc1.weight", "decoder.layers.10.fc1.bias", "decoder.layers.10.fc2.weight", "decoder.layers.10.fc2.bias", "decoder.layers.10.final_layer_norm.weight", "decoder.layers.10.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.11.flat_param_0": {"names": ["decoder.layers.11.self_attn.qkv_proj.weight", "decoder.layers.11.self_attn.qkv_proj.bias", "decoder.layers.11.self_attn.out_proj.weight", "decoder.layers.11.self_attn.out_proj.bias", "decoder.layers.11.self_attn_layer_norm.weight", "decoder.layers.11.self_attn_layer_norm.bias", "decoder.layers.11.fc1.weight", "decoder.layers.11.fc1.bias", "decoder.layers.11.fc2.weight", "decoder.layers.11.fc2.bias", "decoder.layers.11.final_layer_norm.weight", "decoder.layers.11.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.12.flat_param_0": {"names": ["decoder.layers.12.self_attn.qkv_proj.weight", "decoder.layers.12.self_attn.qkv_proj.bias", "decoder.layers.12.self_attn.out_proj.weight", "decoder.layers.12.self_attn.out_proj.bias", "decoder.layers.12.self_attn_layer_norm.weight", "decoder.layers.12.self_attn_layer_norm.bias", "decoder.layers.12.fc1.weight", "decoder.layers.12.fc1.bias", "decoder.layers.12.fc2.weight", "decoder.layers.12.fc2.bias", "decoder.layers.12.final_layer_norm.weight", "decoder.layers.12.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.13.flat_param_0": {"names": ["decoder.layers.13.self_attn.qkv_proj.weight", "decoder.layers.13.self_attn.qkv_proj.bias", "decoder.layers.13.self_attn.out_proj.weight", "decoder.layers.13.self_attn.out_proj.bias", "decoder.layers.13.self_attn_layer_norm.weight", "decoder.layers.13.self_attn_layer_norm.bias", "decoder.layers.13.fc1.weight", "decoder.layers.13.fc1.bias", "decoder.layers.13.fc2.weight", "decoder.layers.13.fc2.bias", "decoder.layers.13.final_layer_norm.weight", "decoder.layers.13.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.14.flat_param_0": {"names": ["decoder.layers.14.self_attn.qkv_proj.weight", "decoder.layers.14.self_attn.qkv_proj.bias", "decoder.layers.14.self_attn.out_proj.weight", "decoder.layers.14.self_attn.out_proj.bias", "decoder.layers.14.self_attn_layer_norm.weight", "decoder.layers.14.self_attn_layer_norm.bias", "decoder.layers.14.fc1.weight", "decoder.layers.14.fc1.bias", "decoder.layers.14.fc2.weight", "decoder.layers.14.fc2.bias", "decoder.layers.14.final_layer_norm.weight", "decoder.layers.14.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.15.flat_param_0": {"names": ["decoder.layers.15.self_attn.qkv_proj.weight", "decoder.layers.15.self_attn.qkv_proj.bias", "decoder.layers.15.self_attn.out_proj.weight", "decoder.layers.15.self_attn.out_proj.bias", "decoder.layers.15.self_attn_layer_norm.weight", "decoder.layers.15.self_attn_layer_norm.bias", "decoder.layers.15.fc1.weight", "decoder.layers.15.fc1.bias", "decoder.layers.15.fc2.weight", "decoder.layers.15.fc2.bias", "decoder.layers.15.final_layer_norm.weight", "decoder.layers.15.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.16.flat_param_0": {"names": ["decoder.layers.16.self_attn.qkv_proj.weight", "decoder.layers.16.self_attn.qkv_proj.bias", "decoder.layers.16.self_attn.out_proj.weight", "decoder.layers.16.self_attn.out_proj.bias", "decoder.layers.16.self_attn_layer_norm.weight", "decoder.layers.16.self_attn_layer_norm.bias", "decoder.layers.16.fc1.weight", "decoder.layers.16.fc1.bias", "decoder.layers.16.fc2.weight", "decoder.layers.16.fc2.bias", "decoder.layers.16.final_layer_norm.weight", "decoder.layers.16.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.17.flat_param_0": {"names": ["decoder.layers.17.self_attn.qkv_proj.weight", "decoder.layers.17.self_attn.qkv_proj.bias", "decoder.layers.17.self_attn.out_proj.weight", "decoder.layers.17.self_attn.out_proj.bias", "decoder.layers.17.self_attn_layer_norm.weight", "decoder.layers.17.self_attn_layer_norm.bias", "decoder.layers.17.fc1.weight", "decoder.layers.17.fc1.bias", "decoder.layers.17.fc2.weight", "decoder.layers.17.fc2.bias", "decoder.layers.17.final_layer_norm.weight", "decoder.layers.17.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.18.flat_param_0": {"names": ["decoder.layers.18.self_attn.qkv_proj.weight", "decoder.layers.18.self_attn.qkv_proj.bias", "decoder.layers.18.self_attn.out_proj.weight", "decoder.layers.18.self_attn.out_proj.bias", "decoder.layers.18.self_attn_layer_norm.weight", "decoder.layers.18.self_attn_layer_norm.bias", "decoder.layers.18.fc1.weight", "decoder.layers.18.fc1.bias", "decoder.layers.18.fc2.weight", "decoder.layers.18.fc2.bias", "decoder.layers.18.final_layer_norm.weight", "decoder.layers.18.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.19.flat_param_0": {"names": ["decoder.layers.19.self_attn.qkv_proj.weight", "decoder.layers.19.self_attn.qkv_proj.bias", "decoder.layers.19.self_attn.out_proj.weight", "decoder.layers.19.self_attn.out_proj.bias", "decoder.layers.19.self_attn_layer_norm.weight", "decoder.layers.19.self_attn_layer_norm.bias", "decoder.layers.19.fc1.weight", "decoder.layers.19.fc1.bias", "decoder.layers.19.fc2.weight", "decoder.layers.19.fc2.bias", "decoder.layers.19.final_layer_norm.weight", "decoder.layers.19.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.20.flat_param_0": {"names": ["decoder.layers.20.self_attn.qkv_proj.weight", "decoder.layers.20.self_attn.qkv_proj.bias", "decoder.layers.20.self_attn.out_proj.weight", "decoder.layers.20.self_attn.out_proj.bias", "decoder.layers.20.self_attn_layer_norm.weight", "decoder.layers.20.self_attn_layer_norm.bias", "decoder.layers.20.fc1.weight", "decoder.layers.20.fc1.bias", "decoder.layers.20.fc2.weight", "decoder.layers.20.fc2.bias", "decoder.layers.20.final_layer_norm.weight", "decoder.layers.20.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.21.flat_param_0": {"names": ["decoder.layers.21.self_attn.qkv_proj.weight", "decoder.layers.21.self_attn.qkv_proj.bias", "decoder.layers.21.self_attn.out_proj.weight", "decoder.layers.21.self_attn.out_proj.bias", "decoder.layers.21.self_attn_layer_norm.weight", "decoder.layers.21.self_attn_layer_norm.bias", "decoder.layers.21.fc1.weight", "decoder.layers.21.fc1.bias", "decoder.layers.21.fc2.weight", "decoder.layers.21.fc2.bias", "decoder.layers.21.final_layer_norm.weight", "decoder.layers.21.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.22.flat_param_0": {"names": ["decoder.layers.22.self_attn.qkv_proj.weight", "decoder.layers.22.self_attn.qkv_proj.bias", "decoder.layers.22.self_attn.out_proj.weight", "decoder.layers.22.self_attn.out_proj.bias", "decoder.layers.22.self_attn_layer_norm.weight", "decoder.layers.22.self_attn_layer_norm.bias", "decoder.layers.22.fc1.weight", "decoder.layers.22.fc1.bias", "decoder.layers.22.fc2.weight", "decoder.layers.22.fc2.bias", "decoder.layers.22.final_layer_norm.weight", "decoder.layers.22.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.23.flat_param_0": {"names": ["decoder.layers.23.self_attn.qkv_proj.weight", "decoder.layers.23.self_attn.qkv_proj.bias", "decoder.layers.23.self_attn.out_proj.weight", "decoder.layers.23.self_attn.out_proj.bias", "decoder.layers.23.self_attn_layer_norm.weight", "decoder.layers.23.self_attn_layer_norm.bias", "decoder.layers.23.fc1.weight", "decoder.layers.23.fc1.bias", "decoder.layers.23.fc2.weight", "decoder.layers.23.fc2.bias", "decoder.layers.23.final_layer_norm.weight", "decoder.layers.23.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.24.flat_param_0": {"names": ["decoder.layers.24.self_attn.qkv_proj.weight", "decoder.layers.24.self_attn.qkv_proj.bias", "decoder.layers.24.self_attn.out_proj.weight", "decoder.layers.24.self_attn.out_proj.bias", "decoder.layers.24.self_attn_layer_norm.weight", "decoder.layers.24.self_attn_layer_norm.bias", "decoder.layers.24.fc1.weight", "decoder.layers.24.fc1.bias", "decoder.layers.24.fc2.weight", "decoder.layers.24.fc2.bias", "decoder.layers.24.final_layer_norm.weight", "decoder.layers.24.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.25.flat_param_0": {"names": ["decoder.layers.25.self_attn.qkv_proj.weight", "decoder.layers.25.self_attn.qkv_proj.bias", "decoder.layers.25.self_attn.out_proj.weight", "decoder.layers.25.self_attn.out_proj.bias", "decoder.layers.25.self_attn_layer_norm.weight", "decoder.layers.25.self_attn_layer_norm.bias", "decoder.layers.25.fc1.weight", "decoder.layers.25.fc1.bias", "decoder.layers.25.fc2.weight", "decoder.layers.25.fc2.bias", "decoder.layers.25.final_layer_norm.weight", "decoder.layers.25.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.26.flat_param_0": {"names": ["decoder.layers.26.self_attn.qkv_proj.weight", "decoder.layers.26.self_attn.qkv_proj.bias", "decoder.layers.26.self_attn.out_proj.weight", "decoder.layers.26.self_attn.out_proj.bias", "decoder.layers.26.self_attn_layer_norm.weight", "decoder.layers.26.self_attn_layer_norm.bias", "decoder.layers.26.fc1.weight", "decoder.layers.26.fc1.bias", "decoder.layers.26.fc2.weight", "decoder.layers.26.fc2.bias", "decoder.layers.26.final_layer_norm.weight", "decoder.layers.26.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.27.flat_param_0": {"names": ["decoder.layers.27.self_attn.qkv_proj.weight", "decoder.layers.27.self_attn.qkv_proj.bias", "decoder.layers.27.self_attn.out_proj.weight", "decoder.layers.27.self_attn.out_proj.bias", "decoder.layers.27.self_attn_layer_norm.weight", "decoder.layers.27.self_attn_layer_norm.bias", "decoder.layers.27.fc1.weight", "decoder.layers.27.fc1.bias", "decoder.layers.27.fc2.weight", "decoder.layers.27.fc2.bias", "decoder.layers.27.final_layer_norm.weight", "decoder.layers.27.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.28.flat_param_0": {"names": ["decoder.layers.28.self_attn.qkv_proj.weight", "decoder.layers.28.self_attn.qkv_proj.bias", "decoder.layers.28.self_attn.out_proj.weight", "decoder.layers.28.self_attn.out_proj.bias", "decoder.layers.28.self_attn_layer_norm.weight", "decoder.layers.28.self_attn_layer_norm.bias", "decoder.layers.28.fc1.weight", "decoder.layers.28.fc1.bias", "decoder.layers.28.fc2.weight", "decoder.layers.28.fc2.bias", "decoder.layers.28.final_layer_norm.weight", "decoder.layers.28.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.29.flat_param_0": {"names": ["decoder.layers.29.self_attn.qkv_proj.weight", "decoder.layers.29.self_attn.qkv_proj.bias", "decoder.layers.29.self_attn.out_proj.weight", "decoder.layers.29.self_attn.out_proj.bias", "decoder.layers.29.self_attn_layer_norm.weight", "decoder.layers.29.self_attn_layer_norm.bias", "decoder.layers.29.fc1.weight", "decoder.layers.29.fc1.bias", "decoder.layers.29.fc2.weight", "decoder.layers.29.fc2.bias", "decoder.layers.29.final_layer_norm.weight", "decoder.layers.29.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.30.flat_param_0": {"names": ["decoder.layers.30.self_attn.qkv_proj.weight", "decoder.layers.30.self_attn.qkv_proj.bias", "decoder.layers.30.self_attn.out_proj.weight", "decoder.layers.30.self_attn.out_proj.bias", "decoder.layers.30.self_attn_layer_norm.weight", "decoder.layers.30.self_attn_layer_norm.bias", "decoder.layers.30.fc1.weight", "decoder.layers.30.fc1.bias", "decoder.layers.30.fc2.weight", "decoder.layers.30.fc2.bias", "decoder.layers.30.final_layer_norm.weight", "decoder.layers.30.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.31.flat_param_0": {"names": ["decoder.layers.31.self_attn.qkv_proj.weight", "decoder.layers.31.self_attn.qkv_proj.bias", "decoder.layers.31.self_attn.out_proj.weight", "decoder.layers.31.self_attn.out_proj.bias", "decoder.layers.31.self_attn_layer_norm.weight", "decoder.layers.31.self_attn_layer_norm.bias", "decoder.layers.31.fc1.weight", "decoder.layers.31.fc1.bias", "decoder.layers.31.fc2.weight", "decoder.layers.31.fc2.bias", "decoder.layers.31.final_layer_norm.weight", "decoder.layers.31.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.32.flat_param_0": {"names": ["decoder.layers.32.self_attn.qkv_proj.weight", "decoder.layers.32.self_attn.qkv_proj.bias", "decoder.layers.32.self_attn.out_proj.weight", "decoder.layers.32.self_attn.out_proj.bias", "decoder.layers.32.self_attn_layer_norm.weight", "decoder.layers.32.self_attn_layer_norm.bias", "decoder.layers.32.fc1.weight", "decoder.layers.32.fc1.bias", "decoder.layers.32.fc2.weight", "decoder.layers.32.fc2.bias", "decoder.layers.32.final_layer_norm.weight", "decoder.layers.32.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.33.flat_param_0": {"names": ["decoder.layers.33.self_attn.qkv_proj.weight", "decoder.layers.33.self_attn.qkv_proj.bias", "decoder.layers.33.self_attn.out_proj.weight", "decoder.layers.33.self_attn.out_proj.bias", "decoder.layers.33.self_attn_layer_norm.weight", "decoder.layers.33.self_attn_layer_norm.bias", "decoder.layers.33.fc1.weight", "decoder.layers.33.fc1.bias", "decoder.layers.33.fc2.weight", "decoder.layers.33.fc2.bias", "decoder.layers.33.final_layer_norm.weight", "decoder.layers.33.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.34.flat_param_0": {"names": ["decoder.layers.34.self_attn.qkv_proj.weight", "decoder.layers.34.self_attn.qkv_proj.bias", "decoder.layers.34.self_attn.out_proj.weight", "decoder.layers.34.self_attn.out_proj.bias", "decoder.layers.34.self_attn_layer_norm.weight", "decoder.layers.34.self_attn_layer_norm.bias", "decoder.layers.34.fc1.weight", "decoder.layers.34.fc1.bias", "decoder.layers.34.fc2.weight", "decoder.layers.34.fc2.bias", "decoder.layers.34.final_layer_norm.weight", "decoder.layers.34.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.35.flat_param_0": {"names": ["decoder.layers.35.self_attn.qkv_proj.weight", "decoder.layers.35.self_attn.qkv_proj.bias", "decoder.layers.35.self_attn.out_proj.weight", "decoder.layers.35.self_attn.out_proj.bias", "decoder.layers.35.self_attn_layer_norm.weight", "decoder.layers.35.self_attn_layer_norm.bias", "decoder.layers.35.fc1.weight", "decoder.layers.35.fc1.bias", "decoder.layers.35.fc2.weight", "decoder.layers.35.fc2.bias", "decoder.layers.35.final_layer_norm.weight", "decoder.layers.35.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.36.flat_param_0": {"names": ["decoder.layers.36.self_attn.qkv_proj.weight", "decoder.layers.36.self_attn.qkv_proj.bias", "decoder.layers.36.self_attn.out_proj.weight", "decoder.layers.36.self_attn.out_proj.bias", "decoder.layers.36.self_attn_layer_norm.weight", "decoder.layers.36.self_attn_layer_norm.bias", "decoder.layers.36.fc1.weight", "decoder.layers.36.fc1.bias", "decoder.layers.36.fc2.weight", "decoder.layers.36.fc2.bias", "decoder.layers.36.final_layer_norm.weight", "decoder.layers.36.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.37.flat_param_0": {"names": ["decoder.layers.37.self_attn.qkv_proj.weight", "decoder.layers.37.self_attn.qkv_proj.bias", "decoder.layers.37.self_attn.out_proj.weight", "decoder.layers.37.self_attn.out_proj.bias", "decoder.layers.37.self_attn_layer_norm.weight", "decoder.layers.37.self_attn_layer_norm.bias", "decoder.layers.37.fc1.weight", "decoder.layers.37.fc1.bias", "decoder.layers.37.fc2.weight", "decoder.layers.37.fc2.bias", "decoder.layers.37.final_layer_norm.weight", "decoder.layers.37.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.38.flat_param_0": {"names": ["decoder.layers.38.self_attn.qkv_proj.weight", "decoder.layers.38.self_attn.qkv_proj.bias", "decoder.layers.38.self_attn.out_proj.weight", "decoder.layers.38.self_attn.out_proj.bias", "decoder.layers.38.self_attn_layer_norm.weight", "decoder.layers.38.self_attn_layer_norm.bias", "decoder.layers.38.fc1.weight", "decoder.layers.38.fc1.bias", "decoder.layers.38.fc2.weight", "decoder.layers.38.fc2.bias", "decoder.layers.38.final_layer_norm.weight", "decoder.layers.38.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.39.flat_param_0": {"names": ["decoder.layers.39.self_attn.qkv_proj.weight", "decoder.layers.39.self_attn.qkv_proj.bias", "decoder.layers.39.self_attn.out_proj.weight", "decoder.layers.39.self_attn.out_proj.bias", "decoder.layers.39.self_attn_layer_norm.weight", "decoder.layers.39.self_attn_layer_norm.bias", "decoder.layers.39.fc1.weight", "decoder.layers.39.fc1.bias", "decoder.layers.39.fc2.weight", "decoder.layers.39.fc2.bias", "decoder.layers.39.final_layer_norm.weight", "decoder.layers.39.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.40.flat_param_0": {"names": ["decoder.layers.40.self_attn.qkv_proj.weight", "decoder.layers.40.self_attn.qkv_proj.bias", "decoder.layers.40.self_attn.out_proj.weight", "decoder.layers.40.self_attn.out_proj.bias", "decoder.layers.40.self_attn_layer_norm.weight", "decoder.layers.40.self_attn_layer_norm.bias", "decoder.layers.40.fc1.weight", "decoder.layers.40.fc1.bias", "decoder.layers.40.fc2.weight", "decoder.layers.40.fc2.bias", "decoder.layers.40.final_layer_norm.weight", "decoder.layers.40.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.41.flat_param_0": {"names": ["decoder.layers.41.self_attn.qkv_proj.weight", "decoder.layers.41.self_attn.qkv_proj.bias", "decoder.layers.41.self_attn.out_proj.weight", "decoder.layers.41.self_attn.out_proj.bias", "decoder.layers.41.self_attn_layer_norm.weight", "decoder.layers.41.self_attn_layer_norm.bias", "decoder.layers.41.fc1.weight", "decoder.layers.41.fc1.bias", "decoder.layers.41.fc2.weight", "decoder.layers.41.fc2.bias", "decoder.layers.41.final_layer_norm.weight", "decoder.layers.41.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.42.flat_param_0": {"names": ["decoder.layers.42.self_attn.qkv_proj.weight", "decoder.layers.42.self_attn.qkv_proj.bias", "decoder.layers.42.self_attn.out_proj.weight", "decoder.layers.42.self_attn.out_proj.bias", "decoder.layers.42.self_attn_layer_norm.weight", "decoder.layers.42.self_attn_layer_norm.bias", "decoder.layers.42.fc1.weight", "decoder.layers.42.fc1.bias", "decoder.layers.42.fc2.weight", "decoder.layers.42.fc2.bias", "decoder.layers.42.final_layer_norm.weight", "decoder.layers.42.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.43.flat_param_0": {"names": ["decoder.layers.43.self_attn.qkv_proj.weight", "decoder.layers.43.self_attn.qkv_proj.bias", "decoder.layers.43.self_attn.out_proj.weight", "decoder.layers.43.self_attn.out_proj.bias", "decoder.layers.43.self_attn_layer_norm.weight", "decoder.layers.43.self_attn_layer_norm.bias", "decoder.layers.43.fc1.weight", "decoder.layers.43.fc1.bias", "decoder.layers.43.fc2.weight", "decoder.layers.43.fc2.bias", "decoder.layers.43.final_layer_norm.weight", "decoder.layers.43.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.44.flat_param_0": {"names": ["decoder.layers.44.self_attn.qkv_proj.weight", "decoder.layers.44.self_attn.qkv_proj.bias", "decoder.layers.44.self_attn.out_proj.weight", "decoder.layers.44.self_attn.out_proj.bias", "decoder.layers.44.self_attn_layer_norm.weight", "decoder.layers.44.self_attn_layer_norm.bias", "decoder.layers.44.fc1.weight", "decoder.layers.44.fc1.bias", "decoder.layers.44.fc2.weight", "decoder.layers.44.fc2.bias", "decoder.layers.44.final_layer_norm.weight", "decoder.layers.44.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.45.flat_param_0": {"names": ["decoder.layers.45.self_attn.qkv_proj.weight", "decoder.layers.45.self_attn.qkv_proj.bias", "decoder.layers.45.self_attn.out_proj.weight", "decoder.layers.45.self_attn.out_proj.bias", "decoder.layers.45.self_attn_layer_norm.weight", "decoder.layers.45.self_attn_layer_norm.bias", "decoder.layers.45.fc1.weight", "decoder.layers.45.fc1.bias", "decoder.layers.45.fc2.weight", "decoder.layers.45.fc2.bias", "decoder.layers.45.final_layer_norm.weight", "decoder.layers.45.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.46.flat_param_0": {"names": ["decoder.layers.46.self_attn.qkv_proj.weight", "decoder.layers.46.self_attn.qkv_proj.bias", "decoder.layers.46.self_attn.out_proj.weight", "decoder.layers.46.self_attn.out_proj.bias", "decoder.layers.46.self_attn_layer_norm.weight", "decoder.layers.46.self_attn_layer_norm.bias", "decoder.layers.46.fc1.weight", "decoder.layers.46.fc1.bias", "decoder.layers.46.fc2.weight", "decoder.layers.46.fc2.bias", "decoder.layers.46.final_layer_norm.weight", "decoder.layers.46.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.47.flat_param_0": {"names": ["decoder.layers.47.self_attn.qkv_proj.weight", "decoder.layers.47.self_attn.qkv_proj.bias", "decoder.layers.47.self_attn.out_proj.weight", "decoder.layers.47.self_attn.out_proj.bias", "decoder.layers.47.self_attn_layer_norm.weight", "decoder.layers.47.self_attn_layer_norm.bias", "decoder.layers.47.fc1.weight", "decoder.layers.47.fc1.bias", "decoder.layers.47.fc2.weight", "decoder.layers.47.fc2.bias", "decoder.layers.47.final_layer_norm.weight", "decoder.layers.47.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.48.flat_param_0": {"names": ["decoder.layers.48.self_attn.qkv_proj.weight", "decoder.layers.48.self_attn.qkv_proj.bias", "decoder.layers.48.self_attn.out_proj.weight", "decoder.layers.48.self_attn.out_proj.bias", "decoder.layers.48.self_attn_layer_norm.weight", "decoder.layers.48.self_attn_layer_norm.bias", "decoder.layers.48.fc1.weight", "decoder.layers.48.fc1.bias", "decoder.layers.48.fc2.weight", "decoder.layers.48.fc2.bias", "decoder.layers.48.final_layer_norm.weight", "decoder.layers.48.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.49.flat_param_0": {"names": ["decoder.layers.49.self_attn.qkv_proj.weight", "decoder.layers.49.self_attn.qkv_proj.bias", "decoder.layers.49.self_attn.out_proj.weight", "decoder.layers.49.self_attn.out_proj.bias", "decoder.layers.49.self_attn_layer_norm.weight", "decoder.layers.49.self_attn_layer_norm.bias", "decoder.layers.49.fc1.weight", "decoder.layers.49.fc1.bias", "decoder.layers.49.fc2.weight", "decoder.layers.49.fc2.bias", "decoder.layers.49.final_layer_norm.weight", "decoder.layers.49.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.50.flat_param_0": {"names": ["decoder.layers.50.self_attn.qkv_proj.weight", "decoder.layers.50.self_attn.qkv_proj.bias", "decoder.layers.50.self_attn.out_proj.weight", "decoder.layers.50.self_attn.out_proj.bias", "decoder.layers.50.self_attn_layer_norm.weight", "decoder.layers.50.self_attn_layer_norm.bias", "decoder.layers.50.fc1.weight", "decoder.layers.50.fc1.bias", "decoder.layers.50.fc2.weight", "decoder.layers.50.fc2.bias", "decoder.layers.50.final_layer_norm.weight", "decoder.layers.50.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.51.flat_param_0": {"names": ["decoder.layers.51.self_attn.qkv_proj.weight", "decoder.layers.51.self_attn.qkv_proj.bias", "decoder.layers.51.self_attn.out_proj.weight", "decoder.layers.51.self_attn.out_proj.bias", "decoder.layers.51.self_attn_layer_norm.weight", "decoder.layers.51.self_attn_layer_norm.bias", "decoder.layers.51.fc1.weight", "decoder.layers.51.fc1.bias", "decoder.layers.51.fc2.weight", "decoder.layers.51.fc2.bias", "decoder.layers.51.final_layer_norm.weight", "decoder.layers.51.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.52.flat_param_0": {"names": ["decoder.layers.52.self_attn.qkv_proj.weight", "decoder.layers.52.self_attn.qkv_proj.bias", "decoder.layers.52.self_attn.out_proj.weight", "decoder.layers.52.self_attn.out_proj.bias", "decoder.layers.52.self_attn_layer_norm.weight", "decoder.layers.52.self_attn_layer_norm.bias", "decoder.layers.52.fc1.weight", "decoder.layers.52.fc1.bias", "decoder.layers.52.fc2.weight", "decoder.layers.52.fc2.bias", "decoder.layers.52.final_layer_norm.weight", "decoder.layers.52.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.53.flat_param_0": {"names": ["decoder.layers.53.self_attn.qkv_proj.weight", "decoder.layers.53.self_attn.qkv_proj.bias", "decoder.layers.53.self_attn.out_proj.weight", "decoder.layers.53.self_attn.out_proj.bias", "decoder.layers.53.self_attn_layer_norm.weight", "decoder.layers.53.self_attn_layer_norm.bias", "decoder.layers.53.fc1.weight", "decoder.layers.53.fc1.bias", "decoder.layers.53.fc2.weight", "decoder.layers.53.fc2.bias", "decoder.layers.53.final_layer_norm.weight", "decoder.layers.53.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.54.flat_param_0": {"names": ["decoder.layers.54.self_attn.qkv_proj.weight", "decoder.layers.54.self_attn.qkv_proj.bias", "decoder.layers.54.self_attn.out_proj.weight", "decoder.layers.54.self_attn.out_proj.bias", "decoder.layers.54.self_attn_layer_norm.weight", "decoder.layers.54.self_attn_layer_norm.bias", "decoder.layers.54.fc1.weight", "decoder.layers.54.fc1.bias", "decoder.layers.54.fc2.weight", "decoder.layers.54.fc2.bias", "decoder.layers.54.final_layer_norm.weight", "decoder.layers.54.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.55.flat_param_0": {"names": ["decoder.layers.55.self_attn.qkv_proj.weight", "decoder.layers.55.self_attn.qkv_proj.bias", "decoder.layers.55.self_attn.out_proj.weight", "decoder.layers.55.self_attn.out_proj.bias", "decoder.layers.55.self_attn_layer_norm.weight", "decoder.layers.55.self_attn_layer_norm.bias", "decoder.layers.55.fc1.weight", "decoder.layers.55.fc1.bias", "decoder.layers.55.fc2.weight", "decoder.layers.55.fc2.bias", "decoder.layers.55.final_layer_norm.weight", "decoder.layers.55.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.56.flat_param_0": {"names": ["decoder.layers.56.self_attn.qkv_proj.weight", "decoder.layers.56.self_attn.qkv_proj.bias", "decoder.layers.56.self_attn.out_proj.weight", "decoder.layers.56.self_attn.out_proj.bias", "decoder.layers.56.self_attn_layer_norm.weight", "decoder.layers.56.self_attn_layer_norm.bias", "decoder.layers.56.fc1.weight", "decoder.layers.56.fc1.bias", "decoder.layers.56.fc2.weight", "decoder.layers.56.fc2.bias", "decoder.layers.56.final_layer_norm.weight", "decoder.layers.56.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.57.flat_param_0": {"names": ["decoder.layers.57.self_attn.qkv_proj.weight", "decoder.layers.57.self_attn.qkv_proj.bias", "decoder.layers.57.self_attn.out_proj.weight", "decoder.layers.57.self_attn.out_proj.bias", "decoder.layers.57.self_attn_layer_norm.weight", "decoder.layers.57.self_attn_layer_norm.bias", "decoder.layers.57.fc1.weight", "decoder.layers.57.fc1.bias", "decoder.layers.57.fc2.weight", "decoder.layers.57.fc2.bias", "decoder.layers.57.final_layer_norm.weight", "decoder.layers.57.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.58.flat_param_0": {"names": ["decoder.layers.58.self_attn.qkv_proj.weight", "decoder.layers.58.self_attn.qkv_proj.bias", "decoder.layers.58.self_attn.out_proj.weight", "decoder.layers.58.self_attn.out_proj.bias", "decoder.layers.58.self_attn_layer_norm.weight", "decoder.layers.58.self_attn_layer_norm.bias", "decoder.layers.58.fc1.weight", "decoder.layers.58.fc1.bias", "decoder.layers.58.fc2.weight", "decoder.layers.58.fc2.bias", "decoder.layers.58.final_layer_norm.weight", "decoder.layers.58.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.59.flat_param_0": {"names": ["decoder.layers.59.self_attn.qkv_proj.weight", "decoder.layers.59.self_attn.qkv_proj.bias", "decoder.layers.59.self_attn.out_proj.weight", "decoder.layers.59.self_attn.out_proj.bias", "decoder.layers.59.self_attn_layer_norm.weight", "decoder.layers.59.self_attn_layer_norm.bias", "decoder.layers.59.fc1.weight", "decoder.layers.59.fc1.bias", "decoder.layers.59.fc2.weight", "decoder.layers.59.fc2.bias", "decoder.layers.59.final_layer_norm.weight", "decoder.layers.59.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.60.flat_param_0": {"names": ["decoder.layers.60.self_attn.qkv_proj.weight", "decoder.layers.60.self_attn.qkv_proj.bias", "decoder.layers.60.self_attn.out_proj.weight", "decoder.layers.60.self_attn.out_proj.bias", "decoder.layers.60.self_attn_layer_norm.weight", "decoder.layers.60.self_attn_layer_norm.bias", "decoder.layers.60.fc1.weight", "decoder.layers.60.fc1.bias", "decoder.layers.60.fc2.weight", "decoder.layers.60.fc2.bias", "decoder.layers.60.final_layer_norm.weight", "decoder.layers.60.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.61.flat_param_0": {"names": ["decoder.layers.61.self_attn.qkv_proj.weight", "decoder.layers.61.self_attn.qkv_proj.bias", "decoder.layers.61.self_attn.out_proj.weight", "decoder.layers.61.self_attn.out_proj.bias", "decoder.layers.61.self_attn_layer_norm.weight", "decoder.layers.61.self_attn_layer_norm.bias", "decoder.layers.61.fc1.weight", "decoder.layers.61.fc1.bias", "decoder.layers.61.fc2.weight", "decoder.layers.61.fc2.bias", "decoder.layers.61.final_layer_norm.weight", "decoder.layers.61.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.62.flat_param_0": {"names": ["decoder.layers.62.self_attn.qkv_proj.weight", "decoder.layers.62.self_attn.qkv_proj.bias", "decoder.layers.62.self_attn.out_proj.weight", "decoder.layers.62.self_attn.out_proj.bias", "decoder.layers.62.self_attn_layer_norm.weight", "decoder.layers.62.self_attn_layer_norm.bias", "decoder.layers.62.fc1.weight", "decoder.layers.62.fc1.bias", "decoder.layers.62.fc2.weight", "decoder.layers.62.fc2.bias", "decoder.layers.62.final_layer_norm.weight", "decoder.layers.62.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.63.flat_param_0": {"names": ["decoder.layers.63.self_attn.qkv_proj.weight", "decoder.layers.63.self_attn.qkv_proj.bias", "decoder.layers.63.self_attn.out_proj.weight", "decoder.layers.63.self_attn.out_proj.bias", "decoder.layers.63.self_attn_layer_norm.weight", "decoder.layers.63.self_attn_layer_norm.bias", "decoder.layers.63.fc1.weight", "decoder.layers.63.fc1.bias", "decoder.layers.63.fc2.weight", "decoder.layers.63.fc2.bias", "decoder.layers.63.final_layer_norm.weight", "decoder.layers.63.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.64.flat_param_0": {"names": ["decoder.layers.64.self_attn.qkv_proj.weight", "decoder.layers.64.self_attn.qkv_proj.bias", "decoder.layers.64.self_attn.out_proj.weight", "decoder.layers.64.self_attn.out_proj.bias", "decoder.layers.64.self_attn_layer_norm.weight", "decoder.layers.64.self_attn_layer_norm.bias", "decoder.layers.64.fc1.weight", "decoder.layers.64.fc1.bias", "decoder.layers.64.fc2.weight", "decoder.layers.64.fc2.bias", "decoder.layers.64.final_layer_norm.weight", "decoder.layers.64.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.65.flat_param_0": {"names": ["decoder.layers.65.self_attn.qkv_proj.weight", "decoder.layers.65.self_attn.qkv_proj.bias", "decoder.layers.65.self_attn.out_proj.weight", "decoder.layers.65.self_attn.out_proj.bias", "decoder.layers.65.self_attn_layer_norm.weight", "decoder.layers.65.self_attn_layer_norm.bias", "decoder.layers.65.fc1.weight", "decoder.layers.65.fc1.bias", "decoder.layers.65.fc2.weight", "decoder.layers.65.fc2.bias", "decoder.layers.65.final_layer_norm.weight", "decoder.layers.65.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.66.flat_param_0": {"names": ["decoder.layers.66.self_attn.qkv_proj.weight", "decoder.layers.66.self_attn.qkv_proj.bias", "decoder.layers.66.self_attn.out_proj.weight", "decoder.layers.66.self_attn.out_proj.bias", "decoder.layers.66.self_attn_layer_norm.weight", "decoder.layers.66.self_attn_layer_norm.bias", "decoder.layers.66.fc1.weight", "decoder.layers.66.fc1.bias", "decoder.layers.66.fc2.weight", "decoder.layers.66.fc2.bias", "decoder.layers.66.final_layer_norm.weight", "decoder.layers.66.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.67.flat_param_0": {"names": ["decoder.layers.67.self_attn.qkv_proj.weight", "decoder.layers.67.self_attn.qkv_proj.bias", "decoder.layers.67.self_attn.out_proj.weight", "decoder.layers.67.self_attn.out_proj.bias", "decoder.layers.67.self_attn_layer_norm.weight", "decoder.layers.67.self_attn_layer_norm.bias", "decoder.layers.67.fc1.weight", "decoder.layers.67.fc1.bias", "decoder.layers.67.fc2.weight", "decoder.layers.67.fc2.bias", "decoder.layers.67.final_layer_norm.weight", "decoder.layers.67.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.68.flat_param_0": {"names": ["decoder.layers.68.self_attn.qkv_proj.weight", "decoder.layers.68.self_attn.qkv_proj.bias", "decoder.layers.68.self_attn.out_proj.weight", "decoder.layers.68.self_attn.out_proj.bias", "decoder.layers.68.self_attn_layer_norm.weight", "decoder.layers.68.self_attn_layer_norm.bias", "decoder.layers.68.fc1.weight", "decoder.layers.68.fc1.bias", "decoder.layers.68.fc2.weight", "decoder.layers.68.fc2.bias", "decoder.layers.68.final_layer_norm.weight", "decoder.layers.68.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.69.flat_param_0": {"names": ["decoder.layers.69.self_attn.qkv_proj.weight", "decoder.layers.69.self_attn.qkv_proj.bias", "decoder.layers.69.self_attn.out_proj.weight", "decoder.layers.69.self_attn.out_proj.bias", "decoder.layers.69.self_attn_layer_norm.weight", "decoder.layers.69.self_attn_layer_norm.bias", "decoder.layers.69.fc1.weight", "decoder.layers.69.fc1.bias", "decoder.layers.69.fc2.weight", "decoder.layers.69.fc2.bias", "decoder.layers.69.final_layer_norm.weight", "decoder.layers.69.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.70.flat_param_0": {"names": ["decoder.layers.70.self_attn.qkv_proj.weight", "decoder.layers.70.self_attn.qkv_proj.bias", "decoder.layers.70.self_attn.out_proj.weight", "decoder.layers.70.self_attn.out_proj.bias", "decoder.layers.70.self_attn_layer_norm.weight", "decoder.layers.70.self_attn_layer_norm.bias", "decoder.layers.70.fc1.weight", "decoder.layers.70.fc1.bias", "decoder.layers.70.fc2.weight", "decoder.layers.70.fc2.bias", "decoder.layers.70.final_layer_norm.weight", "decoder.layers.70.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.71.flat_param_0": {"names": ["decoder.layers.71.self_attn.qkv_proj.weight", "decoder.layers.71.self_attn.qkv_proj.bias", "decoder.layers.71.self_attn.out_proj.weight", "decoder.layers.71.self_attn.out_proj.bias", "decoder.layers.71.self_attn_layer_norm.weight", "decoder.layers.71.self_attn_layer_norm.bias", "decoder.layers.71.fc1.weight", "decoder.layers.71.fc1.bias", "decoder.layers.71.fc2.weight", "decoder.layers.71.fc2.bias", "decoder.layers.71.final_layer_norm.weight", "decoder.layers.71.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.72.flat_param_0": {"names": ["decoder.layers.72.self_attn.qkv_proj.weight", "decoder.layers.72.self_attn.qkv_proj.bias", "decoder.layers.72.self_attn.out_proj.weight", "decoder.layers.72.self_attn.out_proj.bias", "decoder.layers.72.self_attn_layer_norm.weight", "decoder.layers.72.self_attn_layer_norm.bias", "decoder.layers.72.fc1.weight", "decoder.layers.72.fc1.bias", "decoder.layers.72.fc2.weight", "decoder.layers.72.fc2.bias", "decoder.layers.72.final_layer_norm.weight", "decoder.layers.72.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.73.flat_param_0": {"names": ["decoder.layers.73.self_attn.qkv_proj.weight", "decoder.layers.73.self_attn.qkv_proj.bias", "decoder.layers.73.self_attn.out_proj.weight", "decoder.layers.73.self_attn.out_proj.bias", "decoder.layers.73.self_attn_layer_norm.weight", "decoder.layers.73.self_attn_layer_norm.bias", "decoder.layers.73.fc1.weight", "decoder.layers.73.fc1.bias", "decoder.layers.73.fc2.weight", "decoder.layers.73.fc2.bias", "decoder.layers.73.final_layer_norm.weight", "decoder.layers.73.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.74.flat_param_0": {"names": ["decoder.layers.74.self_attn.qkv_proj.weight", "decoder.layers.74.self_attn.qkv_proj.bias", "decoder.layers.74.self_attn.out_proj.weight", "decoder.layers.74.self_attn.out_proj.bias", "decoder.layers.74.self_attn_layer_norm.weight", "decoder.layers.74.self_attn_layer_norm.bias", "decoder.layers.74.fc1.weight", "decoder.layers.74.fc1.bias", "decoder.layers.74.fc2.weight", "decoder.layers.74.fc2.bias", "decoder.layers.74.final_layer_norm.weight", "decoder.layers.74.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.75.flat_param_0": {"names": ["decoder.layers.75.self_attn.qkv_proj.weight", "decoder.layers.75.self_attn.qkv_proj.bias", "decoder.layers.75.self_attn.out_proj.weight", "decoder.layers.75.self_attn.out_proj.bias", "decoder.layers.75.self_attn_layer_norm.weight", "decoder.layers.75.self_attn_layer_norm.bias", "decoder.layers.75.fc1.weight", "decoder.layers.75.fc1.bias", "decoder.layers.75.fc2.weight", "decoder.layers.75.fc2.bias", "decoder.layers.75.final_layer_norm.weight", "decoder.layers.75.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.76.flat_param_0": {"names": ["decoder.layers.76.self_attn.qkv_proj.weight", "decoder.layers.76.self_attn.qkv_proj.bias", "decoder.layers.76.self_attn.out_proj.weight", "decoder.layers.76.self_attn.out_proj.bias", "decoder.layers.76.self_attn_layer_norm.weight", "decoder.layers.76.self_attn_layer_norm.bias", "decoder.layers.76.fc1.weight", "decoder.layers.76.fc1.bias", "decoder.layers.76.fc2.weight", "decoder.layers.76.fc2.bias", "decoder.layers.76.final_layer_norm.weight", "decoder.layers.76.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.77.flat_param_0": {"names": ["decoder.layers.77.self_attn.qkv_proj.weight", "decoder.layers.77.self_attn.qkv_proj.bias", "decoder.layers.77.self_attn.out_proj.weight", "decoder.layers.77.self_attn.out_proj.bias", "decoder.layers.77.self_attn_layer_norm.weight", "decoder.layers.77.self_attn_layer_norm.bias", "decoder.layers.77.fc1.weight", "decoder.layers.77.fc1.bias", "decoder.layers.77.fc2.weight", "decoder.layers.77.fc2.bias", "decoder.layers.77.final_layer_norm.weight", "decoder.layers.77.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.78.flat_param_0": {"names": ["decoder.layers.78.self_attn.qkv_proj.weight", "decoder.layers.78.self_attn.qkv_proj.bias", "decoder.layers.78.self_attn.out_proj.weight", "decoder.layers.78.self_attn.out_proj.bias", "decoder.layers.78.self_attn_layer_norm.weight", "decoder.layers.78.self_attn_layer_norm.bias", "decoder.layers.78.fc1.weight", "decoder.layers.78.fc1.bias", "decoder.layers.78.fc2.weight", "decoder.layers.78.fc2.bias", "decoder.layers.78.final_layer_norm.weight", "decoder.layers.78.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.79.flat_param_0": {"names": ["decoder.layers.79.self_attn.qkv_proj.weight", "decoder.layers.79.self_attn.qkv_proj.bias", "decoder.layers.79.self_attn.out_proj.weight", "decoder.layers.79.self_attn.out_proj.bias", "decoder.layers.79.self_attn_layer_norm.weight", "decoder.layers.79.self_attn_layer_norm.bias", "decoder.layers.79.fc1.weight", "decoder.layers.79.fc1.bias", "decoder.layers.79.fc2.weight", "decoder.layers.79.fc2.bias", "decoder.layers.79.final_layer_norm.weight", "decoder.layers.79.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.80.flat_param_0": {"names": ["decoder.layers.80.self_attn.qkv_proj.weight", "decoder.layers.80.self_attn.qkv_proj.bias", "decoder.layers.80.self_attn.out_proj.weight", "decoder.layers.80.self_attn.out_proj.bias", "decoder.layers.80.self_attn_layer_norm.weight", "decoder.layers.80.self_attn_layer_norm.bias", "decoder.layers.80.fc1.weight", "decoder.layers.80.fc1.bias", "decoder.layers.80.fc2.weight", "decoder.layers.80.fc2.bias", "decoder.layers.80.final_layer_norm.weight", "decoder.layers.80.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.81.flat_param_0": {"names": ["decoder.layers.81.self_attn.qkv_proj.weight", "decoder.layers.81.self_attn.qkv_proj.bias", "decoder.layers.81.self_attn.out_proj.weight", "decoder.layers.81.self_attn.out_proj.bias", "decoder.layers.81.self_attn_layer_norm.weight", "decoder.layers.81.self_attn_layer_norm.bias", "decoder.layers.81.fc1.weight", "decoder.layers.81.fc1.bias", "decoder.layers.81.fc2.weight", "decoder.layers.81.fc2.bias", "decoder.layers.81.final_layer_norm.weight", "decoder.layers.81.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.82.flat_param_0": {"names": ["decoder.layers.82.self_attn.qkv_proj.weight", "decoder.layers.82.self_attn.qkv_proj.bias", "decoder.layers.82.self_attn.out_proj.weight", "decoder.layers.82.self_attn.out_proj.bias", "decoder.layers.82.self_attn_layer_norm.weight", "decoder.layers.82.self_attn_layer_norm.bias", "decoder.layers.82.fc1.weight", "decoder.layers.82.fc1.bias", "decoder.layers.82.fc2.weight", "decoder.layers.82.fc2.bias", "decoder.layers.82.final_layer_norm.weight", "decoder.layers.82.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.83.flat_param_0": {"names": ["decoder.layers.83.self_attn.qkv_proj.weight", "decoder.layers.83.self_attn.qkv_proj.bias", "decoder.layers.83.self_attn.out_proj.weight", "decoder.layers.83.self_attn.out_proj.bias", "decoder.layers.83.self_attn_layer_norm.weight", "decoder.layers.83.self_attn_layer_norm.bias", "decoder.layers.83.fc1.weight", "decoder.layers.83.fc1.bias", "decoder.layers.83.fc2.weight", "decoder.layers.83.fc2.bias", "decoder.layers.83.final_layer_norm.weight", "decoder.layers.83.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.84.flat_param_0": {"names": ["decoder.layers.84.self_attn.qkv_proj.weight", "decoder.layers.84.self_attn.qkv_proj.bias", "decoder.layers.84.self_attn.out_proj.weight", "decoder.layers.84.self_attn.out_proj.bias", "decoder.layers.84.self_attn_layer_norm.weight", "decoder.layers.84.self_attn_layer_norm.bias", "decoder.layers.84.fc1.weight", "decoder.layers.84.fc1.bias", "decoder.layers.84.fc2.weight", "decoder.layers.84.fc2.bias", "decoder.layers.84.final_layer_norm.weight", "decoder.layers.84.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.85.flat_param_0": {"names": ["decoder.layers.85.self_attn.qkv_proj.weight", "decoder.layers.85.self_attn.qkv_proj.bias", "decoder.layers.85.self_attn.out_proj.weight", "decoder.layers.85.self_attn.out_proj.bias", "decoder.layers.85.self_attn_layer_norm.weight", "decoder.layers.85.self_attn_layer_norm.bias", "decoder.layers.85.fc1.weight", "decoder.layers.85.fc1.bias", "decoder.layers.85.fc2.weight", "decoder.layers.85.fc2.bias", "decoder.layers.85.final_layer_norm.weight", "decoder.layers.85.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.86.flat_param_0": {"names": ["decoder.layers.86.self_attn.qkv_proj.weight", "decoder.layers.86.self_attn.qkv_proj.bias", "decoder.layers.86.self_attn.out_proj.weight", "decoder.layers.86.self_attn.out_proj.bias", "decoder.layers.86.self_attn_layer_norm.weight", "decoder.layers.86.self_attn_layer_norm.bias", "decoder.layers.86.fc1.weight", "decoder.layers.86.fc1.bias", "decoder.layers.86.fc2.weight", "decoder.layers.86.fc2.bias", "decoder.layers.86.final_layer_norm.weight", "decoder.layers.86.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.87.flat_param_0": {"names": ["decoder.layers.87.self_attn.qkv_proj.weight", "decoder.layers.87.self_attn.qkv_proj.bias", "decoder.layers.87.self_attn.out_proj.weight", "decoder.layers.87.self_attn.out_proj.bias", "decoder.layers.87.self_attn_layer_norm.weight", "decoder.layers.87.self_attn_layer_norm.bias", "decoder.layers.87.fc1.weight", "decoder.layers.87.fc1.bias", "decoder.layers.87.fc2.weight", "decoder.layers.87.fc2.bias", "decoder.layers.87.final_layer_norm.weight", "decoder.layers.87.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.88.flat_param_0": {"names": ["decoder.layers.88.self_attn.qkv_proj.weight", "decoder.layers.88.self_attn.qkv_proj.bias", "decoder.layers.88.self_attn.out_proj.weight", "decoder.layers.88.self_attn.out_proj.bias", "decoder.layers.88.self_attn_layer_norm.weight", "decoder.layers.88.self_attn_layer_norm.bias", "decoder.layers.88.fc1.weight", "decoder.layers.88.fc1.bias", "decoder.layers.88.fc2.weight", "decoder.layers.88.fc2.bias", "decoder.layers.88.final_layer_norm.weight", "decoder.layers.88.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.89.flat_param_0": {"names": ["decoder.layers.89.self_attn.qkv_proj.weight", "decoder.layers.89.self_attn.qkv_proj.bias", "decoder.layers.89.self_attn.out_proj.weight", "decoder.layers.89.self_attn.out_proj.bias", "decoder.layers.89.self_attn_layer_norm.weight", "decoder.layers.89.self_attn_layer_norm.bias", "decoder.layers.89.fc1.weight", "decoder.layers.89.fc1.bias", "decoder.layers.89.fc2.weight", "decoder.layers.89.fc2.bias", "decoder.layers.89.final_layer_norm.weight", "decoder.layers.89.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.90.flat_param_0": {"names": ["decoder.layers.90.self_attn.qkv_proj.weight", "decoder.layers.90.self_attn.qkv_proj.bias", "decoder.layers.90.self_attn.out_proj.weight", "decoder.layers.90.self_attn.out_proj.bias", "decoder.layers.90.self_attn_layer_norm.weight", "decoder.layers.90.self_attn_layer_norm.bias", "decoder.layers.90.fc1.weight", "decoder.layers.90.fc1.bias", "decoder.layers.90.fc2.weight", "decoder.layers.90.fc2.bias", "decoder.layers.90.final_layer_norm.weight", "decoder.layers.90.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.91.flat_param_0": {"names": ["decoder.layers.91.self_attn.qkv_proj.weight", "decoder.layers.91.self_attn.qkv_proj.bias", "decoder.layers.91.self_attn.out_proj.weight", "decoder.layers.91.self_attn.out_proj.bias", "decoder.layers.91.self_attn_layer_norm.weight", "decoder.layers.91.self_attn_layer_norm.bias", "decoder.layers.91.fc1.weight", "decoder.layers.91.fc1.bias", "decoder.layers.91.fc2.weight", "decoder.layers.91.fc2.bias", "decoder.layers.91.final_layer_norm.weight", "decoder.layers.91.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.92.flat_param_0": {"names": ["decoder.layers.92.self_attn.qkv_proj.weight", "decoder.layers.92.self_attn.qkv_proj.bias", "decoder.layers.92.self_attn.out_proj.weight", "decoder.layers.92.self_attn.out_proj.bias", "decoder.layers.92.self_attn_layer_norm.weight", "decoder.layers.92.self_attn_layer_norm.bias", "decoder.layers.92.fc1.weight", "decoder.layers.92.fc1.bias", "decoder.layers.92.fc2.weight", "decoder.layers.92.fc2.bias", "decoder.layers.92.final_layer_norm.weight", "decoder.layers.92.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.93.flat_param_0": {"names": ["decoder.layers.93.self_attn.qkv_proj.weight", "decoder.layers.93.self_attn.qkv_proj.bias", "decoder.layers.93.self_attn.out_proj.weight", "decoder.layers.93.self_attn.out_proj.bias", "decoder.layers.93.self_attn_layer_norm.weight", "decoder.layers.93.self_attn_layer_norm.bias", "decoder.layers.93.fc1.weight", "decoder.layers.93.fc1.bias", "decoder.layers.93.fc2.weight", "decoder.layers.93.fc2.bias", "decoder.layers.93.final_layer_norm.weight", "decoder.layers.93.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.94.flat_param_0": {"names": ["decoder.layers.94.self_attn.qkv_proj.weight", "decoder.layers.94.self_attn.qkv_proj.bias", "decoder.layers.94.self_attn.out_proj.weight", "decoder.layers.94.self_attn.out_proj.bias", "decoder.layers.94.self_attn_layer_norm.weight", "decoder.layers.94.self_attn_layer_norm.bias", "decoder.layers.94.fc1.weight", "decoder.layers.94.fc1.bias", "decoder.layers.94.fc2.weight", "decoder.layers.94.fc2.bias", "decoder.layers.94.final_layer_norm.weight", "decoder.layers.94.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.95.flat_param_0": {"names": ["decoder.layers.95.self_attn.qkv_proj.weight", "decoder.layers.95.self_attn.qkv_proj.bias", "decoder.layers.95.self_attn.out_proj.weight", "decoder.layers.95.self_attn.out_proj.bias", "decoder.layers.95.self_attn_layer_norm.weight", "decoder.layers.95.self_attn_layer_norm.bias", "decoder.layers.95.fc1.weight", "decoder.layers.95.fc1.bias", "decoder.layers.95.fc2.weight", "decoder.layers.95.fc2.bias", "decoder.layers.95.final_layer_norm.weight", "decoder.layers.95.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}} \ No newline at end of file diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/unflat.sh b/examples/tutorial/opt/inference/script/process-opt-175b/unflat.sh new file mode 100644 index 0000000000000000000000000000000000000000..cc5c190e24e82cc1e0b6d6177b0fe691a5942385 --- /dev/null +++ b/examples/tutorial/opt/inference/script/process-opt-175b/unflat.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env sh + +for i in $(seq 0 7); do + python convert_ckpt.py $1 $2 ${i} & +done + +wait $(jobs -p) diff --git a/examples/tutorial/opt/inference/script/processing_ckpt_66b.py b/examples/tutorial/opt/inference/script/processing_ckpt_66b.py new file mode 100644 index 0000000000000000000000000000000000000000..0494647d7bcce31a6ba1e3decf0943043730c2fa --- /dev/null +++ b/examples/tutorial/opt/inference/script/processing_ckpt_66b.py @@ -0,0 +1,55 @@ +import os +import torch +from multiprocessing import Pool + +# download pytorch model ckpt in https://huggingface.co/facebook/opt-66b/tree/main +# you can use whether wget or git lfs + +path = "/path/to/your/ckpt" +new_path = "/path/to/the/processed/ckpt/" + +assert os.path.isdir(path) +files = [] +for filename in os.listdir(path): + filepath = os.path.join(path, filename) + if os.path.isfile(filepath): + files.append(filepath) + +with Pool(14) as pool: + ckpts = pool.map(torch.load, files) + +restored = {} +for ckpt in ckpts: + for k,v in ckpt.items(): + if(k[0] == 'm'): + k = k[6:] + if(k == "lm_head.weight"): + k = "head.dense.weight" + if(k == "decoder.final_layer_norm.weight"): + k = "decoder.layer_norm.weight" + if(k == "decoder.final_layer_norm.bias"): + k = "decoder.layer_norm.bias" + restored[k] = v +restored["decoder.version"] = "0.0" + + +split_num = len(restored.keys()) // 60 +count = 0 +file_count = 1 +tmp = {} +for k,v in restored.items(): + print(k) + tmp[k] = v + count = count + 1 + if(count == split_num): + filename = str(file_count) + "-restored.pt" + torch.save(tmp, os.path.join(new_path, filename)) + file_count = file_count + 1 + count = 0 + tmp = {} + +filename = str(file_count) + "-restored.pt" +torch.save(tmp, os.path.join(new_path, filename)) + + + diff --git a/examples/tutorial/opt/opt/README.md b/examples/tutorial/opt/opt/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a01209cbda0e58f931289e794ac0f79b9e7f4b47 --- /dev/null +++ b/examples/tutorial/opt/opt/README.md @@ -0,0 +1,76 @@ + +# Train OPT model with Colossal-AI + + +## OPT +Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. + +The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. + +We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before +the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). + +## Our Modifications +We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP. + +## 🚀Quick Start for Tutorial +1. Install the dependency +```bash +pip install datasets accelerate +``` +2. Run finetuning with synthetic datasets with one GPU +```bash +bash ./run_clm_synthetic.sh +``` +3. Run finetuning with 4 GPUs +```bash +bash ./run_clm_synthetic.sh 16 0 125m 4 +``` + +## Quick Start for Practical Use +You can launch training by using the following bash script + +```bash +bash ./run_clm.sh +``` + +- batch-size-per-gpu: number of samples fed to each GPU, default is 16 +- mem-cap: limit memory usage within a value in GB, default is 0 (no limit) +- model: the size of the OPT model, default is `6.7b`. Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7`, `13b`, `30b`, `66b`. For `175b`, you can request +the pretrained weights from [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT). +- gpu-num: the number of GPUs to use, default is 1. + +It uses `wikitext` dataset. + +To use synthetic dataset: + +```bash +bash ./run_clm_synthetic.sh +``` + +## Remarkable Performance +On a single GPU, Colossal-AI’s automatic strategy provides remarkable performance gains from the ZeRO Offloading strategy by Microsoft DeepSpeed. +Users can experience up to a 40% speedup, at a variety of model scales. However, when using a traditional deep learning training framework like PyTorch, a single GPU can no longer support the training of models at such a scale. + +

+ +

+ +Adopting the distributed training strategy with 8 GPUs is as simple as adding a `-nprocs 8` to the training command of Colossal-AI! + +More details about behind the scenes can be found on the corresponding [blog](https://medium.com/@yangyou_berkeley/colossal-ai-seamlessly-accelerates-large-models-at-low-costs-with-hugging-face-4d1a887e500d), +and a detailed tutorial will be added in [Documentation](https://www.colossalai.org/docs/get_started/installation) very soon. diff --git a/examples/tutorial/opt/opt/benchmark.sh b/examples/tutorial/opt/opt/benchmark.sh new file mode 100644 index 0000000000000000000000000000000000000000..f02f7629ad16f42fc04f131c26869477fc59aa90 --- /dev/null +++ b/examples/tutorial/opt/opt/benchmark.sh @@ -0,0 +1,21 @@ +export BS=16 +export MEMCAP=0 +export MODEL="6.7b" +export GPUNUM=1 + +for MODEL in "6.7b" "13b" "1.3b" +do +for GPUNUM in 8 1 +do +for BS in 16 24 32 8 +do +for MEMCAP in 0 40 +do +pkill -9 torchrun +pkill -9 python + +bash ./run_clm.sh $BS $MEMCAP $MODEL $GPUNUM +done +done +done +done diff --git a/examples/tutorial/opt/opt/colossalai_zero.py b/examples/tutorial/opt/opt/colossalai_zero.py new file mode 100644 index 0000000000000000000000000000000000000000..7c2c152450c5563a19a80c3e14b8463fa8423c9a --- /dev/null +++ b/examples/tutorial/opt/opt/colossalai_zero.py @@ -0,0 +1,10 @@ +try: + from colossalai.zero.shard_utils import TensorShardStrategy +except ImportError: + # colossalai > 0.2.8 + from colossalai.zero.legacy import TensorShardStrategy + +zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), + tensor_placement_policy="auto", + reuse_fp16_shard=True), + optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384)) diff --git a/examples/tutorial/opt/opt/context.py b/examples/tutorial/opt/opt/context.py new file mode 100644 index 0000000000000000000000000000000000000000..95f0abf1d8c92ed5766e5f0fa2c70618be7827c5 --- /dev/null +++ b/examples/tutorial/opt/opt/context.py @@ -0,0 +1,32 @@ +import torch.distributed as dist + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + + +class barrier_context(): + """ + This context manager is used to allow one process to execute while blocking all + other processes in the same process group. This is often useful when downloading is required + as we only want to download in one process to prevent file corruption. + Args: + executor_rank (int): the process rank to execute without blocking, all other processes will be blocked + parallel_mode (ParallelMode): the parallel mode corresponding to a process group + Usage: + with barrier_context(): + dataset = CIFAR10(root='./data', download=True) + """ + + def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL): + # the class name is lowercase by convention + current_rank = gpc.get_local_rank(parallel_mode=parallel_mode) + self.should_block = current_rank != executor_rank + self.group = gpc.get_group(parallel_mode=parallel_mode) + + def __enter__(self): + if self.should_block: + dist.barrier(group=self.group) + + def __exit__(self, exc_type, exc_value, exc_traceback): + if not self.should_block: + dist.barrier(group=self.group) diff --git a/examples/tutorial/opt/opt/requirements.txt b/examples/tutorial/opt/opt/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d0ed2c717aee74916ee54a29ebcf69a6d41ca2f5 --- /dev/null +++ b/examples/tutorial/opt/opt/requirements.txt @@ -0,0 +1,7 @@ +colossalai +torch >= 1.8.1 +datasets >= 1.8.0 +sentencepiece != 0.1.92 +protobuf +accelerate == 0.13.2 +transformers diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py new file mode 100755 index 0000000000000000000000000000000000000000..fdc86adab66578156809231ede5c16064e57ad71 --- /dev/null +++ b/examples/tutorial/opt/opt/run_clm.py @@ -0,0 +1,638 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. + +import math +import os +import time +from itertools import chain + +import datasets +import torch +import torch.distributed as dist +import transformers +from accelerate.utils import set_seed +from context import barrier_context +from datasets import load_dataset +from packaging import version +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoTokenizer, + GPT2Tokenizer, + OPTForCausalLM, + SchedulerType, + default_data_collator, + get_scheduler, +) +from transformers.utils.versions import require_version + +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor import ProcessGroup +from colossalai.utils import get_current_device, get_dataloader +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def get_time_stamp(): + torch.cuda.synchronize() + return time.time() + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument("-s", "--synthetic", action="store_true") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument("--train_file", + type=str, + default=None, + help="A csv or a json file containing the training data.") + parser.add_argument("--validation_file", + type=str, + default=None, + help="A csv or a json file containing the validation data.") + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=True, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument("--num_warmup_steps", + type=int, + default=0, + help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--block_size", + type=int, + default=None, + help=("Optional input sequence length after tokenization. The training dataset will be truncated in block of" + " this size for training. Default to the model max input length for single sentence inputs (take into" + " account special tokens)."), + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=None, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument("--overwrite_cache", + type=bool, + default=False, + help="Overwrite the cached training and evaluation sets") + parser.add_argument("--no_keep_linebreaks", + action="store_true", + help="Do not keep line breaks when using TXT files.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_model_id", + type=str, + help="The name of the repository to keep in sync with the local `output_dir`.") + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--with_tracking", + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="all", + help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed."), + ) + + parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") + parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") + args = parser.parse_args() + + # Sanity checks + if not args.synthetic: + if args.dataset_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + + +def colo_memory_cap(size_in_GB): + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) + if size_in_GB * (1024**3) < cuda_capacity: + colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) + print("Using {} GB of GPU memory".format(size_in_GB)) + + +class DummyDataloader: + + def __init__(self, length, batch_size, seq_len, vocab_size): + self.length = length + self.batch_size = batch_size + self.seq_len = seq_len + self.vocab_size = vocab_size + + def generate(self): + input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device()) + attention_mask = torch.ones_like(input_ids) + return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids} + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.generate() + else: + raise StopIteration + + def __len__(self): + return self.length + + +def main(): + args = parse_args() + disable_existing_loggers() + colossalai.launch_from_torch(config=dict()) + logger = get_dist_logger() + is_main_process = dist.get_rank() == 0 + + if is_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + if args.mem_cap > 0: + colo_memory_cap(args.mem_cap) + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}") + + # Handle the repository creation + with barrier_context(): + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + logger.info("Start preparing dataset", ranks=[0]) + if not args.synthetic: + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[:{args.validation_split_percentage}%]", + ) + raw_datasets["train"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[{args.validation_split_percentage}%:]", + ) + else: + data_files = {} + dataset_args = {} + if args.train_file is not None: + data_files["train"] = args.train_file + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks + raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args) + # If no validation data is there, validation_split_percentage will be used to divide the dataset. + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{args.validation_split_percentage}%]", + **dataset_args, + ) + raw_datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{args.validation_split_percentage}%:]", + **dataset_args, + ) + logger.info("Dataset is prepared", ranks=[0]) + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + if args.config_name: + config = AutoConfig.from_pretrained(args.config_name) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained(args.model_name_or_path) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + logger.info("Model config has been created", ranks=[0]) + + if args.model_name_or_path == 'facebook/opt-13b': + tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path) + else: + print(f'load model from {args.model_name_or_path}') + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) + logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0]) + + if args.init_in_cpu: + init_dev = torch.device('cpu') + else: + init_dev = get_current_device() + + # build model + if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b': + # currently, there has a bug in pretrained opt-13b + # we can not import it until huggingface fix it + logger.info("Train a new model from scratch", ranks=[0]) + with ColoInitContext(device=init_dev): + model = OPTForCausalLM(config) + else: + logger.info("Finetune a pre-trained model", ranks=[0]) + with ColoInitContext(device=init_dev): + model = OPTForCausalLM.from_pretrained(args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + local_files_only=False) + + # enable graident checkpointing + model.gradient_checkpointing_enable() + + PLACEMENT_POLICY = 'auto' + cai_version = colossalai.__version__ + logger.info(f'using Colossal-AI version {cai_version}') + if version.parse(cai_version) > version.parse("0.1.10"): + try: + from colossalai.nn.parallel import GeminiDDP + except ImportError: + # this works for unreleased main branch, and this may be released on 0.2.9 + from colossalai.zero import GeminiDDP + model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) + elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): + from colossalai.gemini import ChunkManager, GeminiManager + pg = ProcessGroup() + chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) + chunk_manager = ChunkManager(chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(PLACEMENT_POLICY)) + gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager) + model = ZeroDDP(model, gemini_manager) + + logger.info(f'{model.__class__.__name__} has been created', ranks=[0]) + + if not args.synthetic: + # Preprocessing the datasets. + # First we tokenize all the texts. + column_names = raw_datasets["train"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + def tokenize_function(examples): + return tokenizer(examples[text_column_name]) + + with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA): + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + + if args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > 1024: + logger.warning( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --block_size xxx.") + block_size = 1024 + else: + if args.block_size > tokenizer.model_max_length: + logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.") + block_size = min(args.block_size, tokenizer.model_max_length) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i:i + block_size] for i in range(0, total_length, block_size) + ] for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + if not args.synthetic: + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder + # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower + # to preprocess. + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map + + with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA): + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=args.preprocessing_num_workers, + load_from_cache_file=not args.overwrite_cache, + desc=f"Grouping texts in chunks of {block_size}", + ) + + train_dataset = lm_datasets["train"] + eval_dataset = lm_datasets["validation"] + + # Log a few random samples from the training set: + # for index in random.sample(range(len(train_dataset)), 3): + # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # DataLoaders creation: + train_dataloader = get_dataloader(train_dataset, + shuffle=True, + add_sampler=True, + collate_fn=default_data_collator, + batch_size=args.per_device_train_batch_size) + eval_dataloader = DataLoader(eval_dataset, + collate_fn=default_data_collator, + batch_size=args.per_device_eval_batch_size) + else: + train_dataloader = DummyDataloader(30, args.per_device_train_batch_size, config.max_position_embeddings, + config.vocab_size) + eval_dataloader = DummyDataloader(10, args.per_device_train_batch_size, config.max_position_embeddings, + config.vocab_size) + logger.info("Dataloaders have been created", ranks=[0]) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate) + optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA) + num_train_samples = len(train_dataset) if not args.synthetic else 30 * total_batch_size + num_eval_samples = len(eval_dataset) if not args.synthetic else 10 * total_batch_size + + logger.info("***** Running training *****", ranks=[0]) + logger.info(f" Num examples = {num_train_samples}", ranks=[0]) + logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0]) + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0]) + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0]) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0]) + logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0]) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process) + completed_steps = 0 + starting_epoch = 0 + global_step = 0 + + for epoch in range(starting_epoch, args.num_train_epochs): + + if completed_steps >= args.max_train_steps: + break + + model.train() + for step, batch in enumerate(train_dataloader): + batch = {k: v.cuda() for k, v in batch.items()} + outputs = model(use_cache=False, **batch) + loss = outputs['loss'] + optimizer.backward(loss) + + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + completed_steps += 1 + + global_step += 1 + logger.info("Global step {} finished".format(global_step + 1), ranks=[0]) + + if completed_steps >= args.max_train_steps: + break + + model.eval() + losses = [] + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + batch = {k: v.cuda() for k, v in batch.items()} + outputs = model(**batch) + + loss = outputs['loss'].unsqueeze(0) + losses.append(loss) + + losses = torch.cat(losses) + losses = losses[:num_eval_samples] + try: + eval_loss = torch.mean(losses) + perplexity = math.exp(eval_loss) + except OverflowError: + perplexity = float("inf") + + logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0]) + + if args.output_dir is not None: + model_state = model.state_dict() + if is_main_process: + torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) + dist.barrier() + # load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) + # model.load_state_dict(load_state, strict=False) + + logger.info("Training finished", ranks=[0]) + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/opt/opt/run_clm.sh b/examples/tutorial/opt/opt/run_clm.sh new file mode 100644 index 0000000000000000000000000000000000000000..858d3325a7b4f9f55592ac4a7d62836f2a0a0501 --- /dev/null +++ b/examples/tutorial/opt/opt/run_clm.sh @@ -0,0 +1,22 @@ +set -x +export BS=${1:-16} +export MEMCAP=${2:-0} +export MODEL=${3:-"125m"} +export GPUNUM=${4:-1} + +# make directory for logs +mkdir -p ./logs + +export MODLE_PATH="facebook/opt-${MODEL}" + +# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 +torchrun \ + --nproc_per_node ${GPUNUM} \ + --master_port 19198 \ + run_clm.py \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --output_dir $PWD \ + --mem_cap ${MEMCAP} \ + --model_name_or_path ${MODLE_PATH} \ + --per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log diff --git a/examples/tutorial/opt/opt/run_clm_synthetic.sh b/examples/tutorial/opt/opt/run_clm_synthetic.sh new file mode 100644 index 0000000000000000000000000000000000000000..80435f16ce2d24b634bdb16017092c5a32ea599b --- /dev/null +++ b/examples/tutorial/opt/opt/run_clm_synthetic.sh @@ -0,0 +1,21 @@ +set -x +export BS=${1:-16} +export MEMCAP=${2:-0} +export MODEL=${3:-"125m"} +export GPUNUM=${4:-1} + +# make directory for logs +mkdir -p ./logs + +export MODLE_PATH="facebook/opt-${MODEL}" + +# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 +torchrun \ + --nproc_per_node ${GPUNUM} \ + --master_port 19198 \ + run_clm.py \ + -s \ + --output_dir $PWD \ + --mem_cap ${MEMCAP} \ + --model_name_or_path ${MODLE_PATH} \ + --per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log diff --git a/examples/tutorial/opt/opt/test_ci.sh b/examples/tutorial/opt/opt/test_ci.sh new file mode 100755 index 0000000000000000000000000000000000000000..e505da1364de04c76f60177824466539af1af416 --- /dev/null +++ b/examples/tutorial/opt/opt/test_ci.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +set -xue + +pip install -r requirements.txt + +BS=8 +MEMCAP=0 +GPUNUM=2 +MODLE="facebook/opt-125m" + +torchrun \ + --nproc_per_node ${GPUNUM} \ + --master_port 19198 \ + run_clm.py \ + -s \ + --output_dir $PWD \ + --mem_cap ${MEMCAP} \ + --model_name_or_path ${MODLE} \ + --per_device_train_batch_size ${BS} \ + --num_train_epochs 1 diff --git a/examples/tutorial/opt/test_ci.sh b/examples/tutorial/opt/test_ci.sh new file mode 100755 index 0000000000000000000000000000000000000000..8341bb10510fbfaf2573fb886f7b5b3b0dedabb4 --- /dev/null +++ b/examples/tutorial/opt/test_ci.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +cd opt && bash test_ci.sh diff --git a/examples/tutorial/requirements.txt b/examples/tutorial/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..137a69e80498223cd7581a62e2e27320b77682a0 --- /dev/null +++ b/examples/tutorial/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/examples/tutorial/sequence_parallel/README.md b/examples/tutorial/sequence_parallel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1b7c60e22861d2d519bab5debf032ef1f8477b47 --- /dev/null +++ b/examples/tutorial/sequence_parallel/README.md @@ -0,0 +1,68 @@ +# Sequence Parallelism + +## Table of contents + +- [Sequence Parallelism](#sequence-parallelism) + - [Table of contents](#table-of-contents) + - [📚 Overview](#-overview) + - [🚀 Quick Start](#-quick-start) + - [🏎 How to Train with Sequence Parallelism](#-how-to-train-with-sequence-parallelism) + - [Step 1. Configure your parameters](#step-1-configure-your-parameters) + - [Step 2. Invoke parallel training](#step-2-invoke-parallel-training) + +## 📚 Overview + +In this tutorial, we implemented BERT with sequence parallelism. Sequence parallelism splits the input tensor and intermediate +activation along the sequence dimension. This method can achieve better memory efficiency and allows us to train with larger batch size and longer sequence length. + +Paper: [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120) + +## 🚀 Quick Start + +1. Install PyTorch + +2. Install the dependencies. + +```bash +pip install -r requirements.txt +``` + +3. Run with the following command + +```bash +export PYTHONPATH=$PWD + +# run with synthetic dataset +colossalai run --nproc_per_node 4 train.py +``` + +> The default config is sequence parallel size = 2, pipeline size = 1, let’s change pipeline size to be 2 and try it again. + + +## 🏎 How to Train with Sequence Parallelism + +We provided `train.py` for you to execute training. Before invoking the script, there are several +steps to perform. + +### Step 1. Configure your parameters + +In the `config.py` provided, a set of parameters are defined including training scheme, model, etc. +You can also modify the ColossalAI setting. For example, if you wish to parallelize over the +sequence dimension on 8 GPUs. You can change `size=4` to `size=8`. If you wish to use pipeline parallelism, you can set `pipeline=`. + +### Step 2. Invoke parallel training + +Lastly, you can start training with sequence parallelism. How you invoke `train.py` depends on your +machine setting. + +- If you are using a single machine with multiple GPUs, PyTorch launch utility can easily let you + start your script. A sample command is like below: + + ```bash + colossalai run --nproc_per_node --master_addr localhost --master_port 29500 train.py + ``` + +- If you are using multiple machines with multiple GPUs, we suggest that you refer to `colossalai + launch_from_slurm` or `colossalai.launch_from_openmpi` as it is easier to use SLURM and OpenMPI + to start multiple processes over multiple nodes. If you have your own launcher, you can fall back + to the default `colossalai.launch` function. diff --git a/examples/tutorial/sequence_parallel/config.py b/examples/tutorial/sequence_parallel/config.py new file mode 100644 index 0000000000000000000000000000000000000000..6edf9cc2c7e5a22d836839bdcc40959ac75b0296 --- /dev/null +++ b/examples/tutorial/sequence_parallel/config.py @@ -0,0 +1,35 @@ +from colossalai.amp import AMP_TYPE + +# hyper-parameters +TRAIN_ITERS = 10 +DECAY_ITERS = 4 +WARMUP_FRACTION = 0.01 +GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU +EVAL_ITERS = 10 +EVAL_INTERVAL = 10 +LR = 0.0001 +MIN_LR = 1e-05 +WEIGHT_DECAY = 0.01 +SEQ_LENGTH = 128 + +# BERT config +DEPTH = 4 +NUM_ATTENTION_HEADS = 4 +HIDDEN_SIZE = 128 + +# model config +ADD_BINARY_HEAD = False + +# random seed +SEED = 1234 + +# pipeline config +# only enabled when pipeline > 1 +NUM_MICRO_BATCHES = 4 + +# colossalai config +parallel = dict(pipeline=1, tensor=dict(size=2, mode='sequence')) + +fp16 = dict(mode=AMP_TYPE.NAIVE, verbose=True) + +gradient_handler = [dict(type='SequenceParallelGradientHandler')] diff --git a/examples/tutorial/sequence_parallel/data/__init__.py b/examples/tutorial/sequence_parallel/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef2d999389fe001b01342e66942c69455327efb --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/__init__.py @@ -0,0 +1,102 @@ +from colossalai.context.parallel_context import ParallelContext +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.context import ParallelMode +from .datasets.data_samplers import build_pretraining_data_loader +from .datasets.builder import build_train_valid_test_datasets +import torch + + +def cyclic_iter(iter): + while True: + for x in iter: + yield x + + +def build_train_valid_test_data_iterators(train_iters, + global_batch_size, + eval_interval, + eval_iters, + dataloader_type='single', + **kwargs + ): + (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) + + logger = get_dist_logger() + logger.info('> building train, validation, and test datasets ...', ranks=[0]) + + # Backward compatibility, assume fixed batch size. + # if iteration > 0 and consumed_train_samples == 0: + # assert train_samples is None, \ + # 'only backward compatibility support for iteration-based training' + # consumed_train_samples = iteration * global_batch_size + # if iteration > 0 and consumed_valid_samples == 0: + # if train_samples is None: + # consumed_valid_samples = (iteration // eval_interval) * \ + # eval_iters * global_batch_size + + # Data loader only on rank 0 of each model parallel group. + if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: + + # Number of train/valid/test samples. + train_samples = train_iters * global_batch_size + eval_iters_ = (train_iters // eval_interval + 1) * eval_iters + test_iters = eval_iters + train_val_test_num_samples = [train_samples, + eval_iters_ * global_batch_size, + test_iters * global_batch_size] + logger.info(' > datasets target sizes (minimum size):') + logger.info(' train: {}'.format(train_val_test_num_samples[0]), ranks=[0]) + logger.info(' validation: {}'.format(train_val_test_num_samples[1]), ranks=[0]) + logger.info(' test: {}'.format(train_val_test_num_samples[2]), ranks=[0]) + + # Build the datasets. + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + train_valid_test_num_samples=train_val_test_num_samples, **kwargs) + + # Build dataloaders. + dp_size = gpc.get_world_size(ParallelMode.DATA) + train_dataloader = build_pretraining_data_loader( + train_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size) + valid_dataloader = build_pretraining_data_loader( + valid_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size) + test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size//dp_size) + + # Flags to know if we need to do training/validation/testing. + do_train = train_dataloader is not None and train_iters > 0 + do_valid = valid_dataloader is not None and eval_iters > 0 + do_test = test_dataloader is not None and eval_iters > 0 + # Need to broadcast num_tokens and num_type_tokens. + flags = torch.cuda.LongTensor( + [int(do_train), int(do_valid), int(do_test)]) + else: + flags = torch.cuda.LongTensor([0, 0, 0]) + + # Broadcast num tokens. + torch.distributed.broadcast(flags, + gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], + group=gpc.get_group(ParallelMode.TENSOR)) + + # Build iterators. + dl_type = dataloader_type + assert dl_type in ['single', 'cyclic'] + + if train_dataloader is not None: + train_data_iterator = iter(train_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(train_dataloader)) + else: + train_data_iterator = None + + if valid_dataloader is not None: + valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(valid_dataloader)) + else: + valid_data_iterator = None + + if test_dataloader is not None: + test_data_iterator = iter(test_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(test_dataloader)) + else: + test_data_iterator = None + + return train_data_iterator, valid_data_iterator, test_data_iterator diff --git a/examples/tutorial/sequence_parallel/data/bert_helper.py b/examples/tutorial/sequence_parallel/data/bert_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..d092db3e7dd8d545253e3a36c6203ace3d0eec9d --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/bert_helper.py @@ -0,0 +1,165 @@ +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +import torch + +_MAX_DATA_DIM = 5 + + +def _build_key_size_numel_dictionaries(keys, data): + """Build the size on rank 0 and broadcast.""" + max_dim = _MAX_DATA_DIM + sizes = [0 for _ in range(max_dim) for _ in keys] + + # Pack the sizes on rank zero. + if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: + offset = 0 + for key in keys: + assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' + size = data[key].size() + for i, s in enumerate(size): + sizes[i + offset] = s + offset += max_dim + + # Move to GPU and broadcast. + sizes_cuda = torch.cuda.LongTensor(sizes) + torch.distributed.broadcast(sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], + group=gpc.get_group(ParallelMode.TENSOR)) + + # Move back to cpu and unpack. + sizes_cpu = sizes_cuda.cpu() + key_size = {} + key_numel = {} + total_numel = 0 + offset = 0 + for key in keys: + i = 0 + size = [] + numel = 1 + while sizes_cpu[offset + i] > 0: + this_size = sizes_cpu[offset + i] + size.append(this_size) + numel *= this_size + i += 1 + key_size[key] = size + key_numel[key] = numel + total_numel += numel + offset += max_dim + + return key_size, key_numel, total_numel + + +def broadcast_data(keys, data, datatype): + """Broadcast data from rank zero of each model parallel group to the + members of the same model parallel group. + + Arguments: + keys: list of keys in the data dictionary to be broadcasted + data: data dictionary of string keys and cpu tensor values. + datatype: torch data type of all tensors in data associated + with keys. + """ + # Build (key, size) and (key, number of elements) dictionaries along + # with the total number of elements on all ranks. + key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, + data) + + # Pack on rank zero. + if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # Check that all keys have the same data type. + # Flatten the data associated with the keys + flatten_data = torch.cat( + [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() + else: + flatten_data = torch.empty(total_numel, + device=torch.cuda.current_device(), + dtype=datatype) + + # Broadcast + torch.distributed.broadcast(flatten_data, + gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], + group=gpc.get_group(ParallelMode.TENSOR)) + + # Unpack + output = {} + offset = 0 + for key in keys: + size = key_size[key] + numel = key_numel[key] + output[key] = flatten_data.narrow(0, offset, numel).view(size) + offset += numel + + return output + + +def get_batch(data_iterator): + """Build the batch.""" + + # Items and their type. + keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = broadcast_data(keys, data, datatype) + + # Unpack. + tokens = data_b['text'].long() + types = data_b['types'].long() + sentence_order = data_b['is_random'].long() + loss_mask = data_b['loss_mask'].float() + lm_labels = data_b['labels'].long() + padding_mask = data_b['padding_mask'].long() + + return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask + + +def get_batch_for_sequence_parallel(data_iterator): + """Build the batch.""" + + # Items and their type. + keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + + # unpack + data_b = broadcast_data(keys, data, datatype) + + # # get tensor parallel local rank + global_rank = torch.distributed.get_rank() + local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR) + local_rank = global_rank % local_world_size + seq_length = data_b['text'].size(1) + sub_seq_length = seq_length // local_world_size + sub_seq_start = local_rank * sub_seq_length + sub_seq_end = (local_rank+1) * sub_seq_length + # + # # Unpack. + tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long() + types = data_b['types'][:, sub_seq_start:sub_seq_end].long() + sentence_order = data_b['is_random'].long() + loss_mask = data_b['loss_mask'][:, sub_seq_start:sub_seq_end].float() + lm_labels = data_b['labels'][:, sub_seq_start:sub_seq_end].long() + padding_mask = data_b['padding_mask'].long() + + return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask + + +class SequenceParallelDataIterator: + + def __init__(self, data_iter): + self.data_iter = data_iter + + + def __iter__(self): + return self.data_iter + + def __next__(self): + return get_batch_for_sequence_parallel(self.data_iter) \ No newline at end of file diff --git a/examples/tutorial/sequence_parallel/data/datasets/Makefile b/examples/tutorial/sequence_parallel/data/datasets/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..8f9db7686696fbea6c94b998db4b40ef426c748d --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/Makefile @@ -0,0 +1,9 @@ +CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color +CPPFLAGS += $(shell python3 -m pybind11 --includes) +LIBNAME = helpers +LIBEXT = $(shell python3-config --extension-suffix) + +default: $(LIBNAME)$(LIBEXT) + +%$(LIBEXT): %.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ diff --git a/examples/tutorial/sequence_parallel/data/datasets/__init__.py b/examples/tutorial/sequence_parallel/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd5f898c6bdf89c6cf0243af102d04f6efed86b8 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/__init__.py @@ -0,0 +1 @@ +from . import indexed_dataset diff --git a/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d6388bd9f8e427575f345c20f38aa276a97be049 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BERT Style dataset.""" + +import os +import time + +import numpy as np +import torch +from torch.utils.data import Dataset + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger + +from ..tokenizer import get_tokenizer +from .dataset_utils import ( + create_masked_lm_predictions, + create_tokens_and_tokentypes, + get_a_and_b_segments, + pad_and_convert_to_numpy, + truncate_segments, +) + +try: + from . import helpers +except: + print("helper is not built, ignore this message if you are using synthetic data.") + + +class BertDataset(Dataset): + + def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_samples, masked_lm_prob, max_seq_length, + short_seq_prob, seed, binary_head): + + # Params to store. + self.name = name + self.seed = seed + self.masked_lm_prob = masked_lm_prob + self.max_seq_length = max_seq_length + self.binary_head = binary_head + + # Dataset. + self.indexed_dataset = indexed_dataset + + # Build the samples mapping. + self.samples_mapping = get_samples_mapping_( + self.indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + self.max_seq_length - 3, # account for added tokens, + short_seq_prob, + self.seed, + self.name, + self.binary_head) + + # Vocab stuff. + tokenizer = get_tokenizer() + self.vocab_id_list = list(tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_dict = tokenizer.inv_vocab + self.cls_id = tokenizer.cls + self.sep_id = tokenizer.sep + self.mask_id = tokenizer.mask + self.pad_id = tokenizer.pad + + def __len__(self): + return self.samples_mapping.shape[0] + + def __getitem__(self, idx): + start_idx, end_idx, seq_length = self.samples_mapping[idx] + sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)] + # Note that this rng state should be numpy and not python since + # python randint is inclusive whereas the numpy one is exclusive. + # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1 + np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32)) + return build_training_sample( + sample, + seq_length, + self.max_seq_length, # needed for padding + self.vocab_id_list, + self.vocab_id_to_token_dict, + self.cls_id, + self.sep_id, + self.mask_id, + self.pad_id, + self.masked_lm_prob, + np_rng, + self.binary_head) + + +def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, + seed, name, binary_head): + logger = get_dist_logger() + if not num_epochs: + if not max_num_samples: + raise ValueError("Need to specify either max_num_samples " + "or num_epochs") + num_epochs = np.iinfo(np.int32).max - 1 + if not max_num_samples: + max_num_samples = np.iinfo(np.int64).max - 1 + + # Filename of the index mapping + indexmap_filename = data_prefix + indexmap_filename += '_{}_indexmap'.format(name) + if num_epochs != (np.iinfo(np.int32).max - 1): + indexmap_filename += '_{}ep'.format(num_epochs) + if max_num_samples != (np.iinfo(np.int64).max - 1): + indexmap_filename += '_{}mns'.format(max_num_samples) + indexmap_filename += '_{}msl'.format(max_seq_length) + indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob) + indexmap_filename += '_{}s'.format(seed) + indexmap_filename += '.npy' + + # Build the indexed mapping if not exist. + if torch.distributed.get_rank() == 0 and \ + not os.path.isfile(indexmap_filename): + print(' > WARNING: could not find index map file {}, building ' + 'the indices on rank 0 ...'.format(indexmap_filename)) + + # Make sure the types match the helpers input types. + assert indexed_dataset.doc_idx.dtype == np.int64 + assert indexed_dataset.sizes.dtype == np.int32 + + # Build samples mapping + verbose = torch.distributed.get_rank() == 0 + start_time = time.time() + logger.info('\n > building samples index mapping for {} ...'.format(name), ranks=[0]) + # First compile and then import. + samples_mapping = helpers.build_mapping(indexed_dataset.doc_idx, indexed_dataset.sizes, num_epochs, + max_num_samples, max_seq_length, short_seq_prob, seed, verbose, + 2 if binary_head else 1) + logger.info('\n > done building samples index maping', ranks=[0]) + np.save(indexmap_filename, samples_mapping, allow_pickle=True) + logger.info('\n > saved the index mapping in {}'.format(indexmap_filename), ranks=[0]) + # Make sure all the ranks have built the mapping + logger.info('\n > elapsed time to build and save samples mapping ' + '(seconds): {:4f}'.format(time.time() - start_time), + ranks=[0]) + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.DATA)) + if gpc.is_initialized(ParallelMode.PIPELINE): + torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.PIPELINE)) + assert counts[0].item() == (torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE))) + + # Load indexed dataset. + start_time = time.time() + samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') + logger.info('\n > loading indexed mapping from {}'.format(indexmap_filename) + + '\n loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time) + + '\n total number of samples: {}'.format(samples_mapping.shape[0]), + ranks=[0]) + + return samples_mapping + + +def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_list, vocab_id_to_token_dict, cls_id, + sep_id, mask_id, pad_id, masked_lm_prob, np_rng, binary_head): + """Build training sample. + + Arguments: + sample: A list of sentences in which each sentence is a list token ids. + target_seq_length: Desired sequence length. + max_seq_length: Maximum length of the sequence. All values are padded to + this length. + vocab_id_list: List of vocabulary ids. Used to pick a random id. + vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. + cls_id: Start of example id. + sep_id: Separator id. + mask_id: Mask token id. + pad_id: Padding token id. + masked_lm_prob: Probability to mask tokens. + np_rng: Random number genenrator. Note that this rng state should be + numpy and not python since python randint is inclusive for + the opper bound whereas the numpy one is exclusive. + """ + + if binary_head: + # We assume that we have at least two sentences in the sample + assert len(sample) > 1 + assert target_seq_length <= max_seq_length + + # Divide sample into two segments (A and B). + if binary_head: + tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng) + else: + tokens_a = [] + for j in range(len(sample)): + tokens_a.extend(sample[j]) + tokens_b = [] + is_next_random = False + + # Truncate to `target_sequence_length`. + max_num_tokens = target_seq_length + truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b), max_num_tokens, np_rng) + + # Build tokens and toketypes. + tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id) + + # Masking. + max_predictions_per_seq = masked_lm_prob * max_num_tokens + (tokens, masked_positions, masked_labels, + _) = create_masked_lm_predictions(tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, cls_id, sep_id, + mask_id, max_predictions_per_seq, np_rng) + + # Padding. + tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \ + = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, + masked_labels, pad_id, max_seq_length) + + train_sample = { + 'text': tokens_np, + 'types': tokentypes_np, + 'labels': labels_np, + 'is_random': int(is_next_random), + 'loss_mask': loss_mask_np, + 'padding_mask': padding_mask_np, + 'truncated': int(truncated) + } + return train_sample diff --git a/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6a06c869d8c808af92ed3a6993a57cca9ca78a8b --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Blendable dataset.""" + +import time + +import numpy as np +import torch + + +class BlendableDataset(torch.utils.data.Dataset): + + def __init__(self, datasets, weights): + + self.datasets = datasets + num_datasets = len(datasets) + assert num_datasets == len(weights) + + self.size = 0 + for dataset in self.datasets: + self.size += len(dataset) + + # Normalize weights. + weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + + # Build indices. + start_time = time.time() + assert num_datasets < 255 + self.dataset_index = np.zeros(self.size, dtype=np.uint8) + self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) + + from . import helpers + helpers.build_blending_indices(self.dataset_index, + self.dataset_sample_index, + weights, num_datasets, self.size, + torch.distributed.get_rank() == 0) + print('> elapsed time for building blendable dataset indices: ' + '{:.2f} (sec)'.format(time.time() - start_time)) + + def __len__(self): + return self.size + + def __getitem__(self, idx): + dataset_idx = self.dataset_index[idx] + sample_idx = self.dataset_sample_index[idx] + return self.datasets[dataset_idx][sample_idx] diff --git a/examples/tutorial/sequence_parallel/data/datasets/builder.py b/examples/tutorial/sequence_parallel/data/datasets/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..6106f833b4628a0763cae82d0a7b6073f5c0548d --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/builder.py @@ -0,0 +1,152 @@ +from .blendable_dataset import BlendableDataset +from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_ +from .bert_dataset import BertDataset +from colossalai.logging import get_dist_logger + +DSET_TYPE_BERT = 'standard_bert' +DSET_TYPE_ICT = 'ict' +DSET_TYPE_T5 = 't5' + +DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5] + + +def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, skip_warmup, + binary_head, + dataset_type='standard_bert'): + + if dataset_type not in DSET_TYPES: + raise ValueError("Invalid dataset_type: ", dataset_type) + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + # Get start and end indices of train/valid/train into doc-idx + # Note that doc-idx is designed to be num-docs + 1 so we can + # easily iterate over it. + total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1 + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + logger = get_dist_logger() + + # Print stats about the splits. + logger.info('\n > dataset split:', ranks=[0]) + + def print_split_stats(name, index): + start_index = indexed_dataset.doc_idx[splits[index]] + end_index = indexed_dataset.doc_idx[splits[index + 1]] + logger.info('\n {}:'.format(name) + + '\n document indices in [{}, {}) total of {} documents'.format( + splits[index], splits[index + 1], + splits[index + 1] - splits[index]) + + '\n sentence indices in [{}, {}) total of {} sentences'.format( + start_index, end_index, + end_index - start_index), + ranks=[0]) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + # Get the pointer to the original doc-idx so we can set it later. + doc_idx_ptr = indexed_dataset.get_doc_idx() + # Slice the doc-idx + start_index = splits[index] + # Add +1 so we can index into the dataset to get the upper bound. + end_index = splits[index + 1] + 1 + # New doc_idx view. + indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) + # Build the dataset accordingly. + kwargs = dict( + name=name, + data_prefix=data_prefix, + num_epochs=None, + max_num_samples=train_valid_test_num_samples[index], + max_seq_length=max_seq_length, + seed=seed, + ) + + if dataset_type != DSET_TYPE_BERT: + raise NotImplementedError("Only BERT dataset is supported") + else: + dataset = BertDataset( + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + short_seq_prob=short_seq_prob, + binary_head=binary_head, + **kwargs + ) + + # Set the original pointer so dataset remains the main dataset. + indexed_dataset.set_doc_idx(doc_idx_ptr) + # Checks. + assert indexed_dataset.doc_idx[0] == 0 + assert indexed_dataset.doc_idx.shape[0] == \ + (total_num_of_documents + 1) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, skip_warmup, + binary_head, + dataset_type='standard_bert'): + + if len(data_prefix) == 1: + return _build_train_valid_test_datasets(data_prefix[0], + data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, + skip_warmup, + binary_head, + dataset_type=dataset_type) + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + prefixes[i], data_impl, splits_string, + datasets_train_valid_test_num_samples[i], + max_seq_length, masked_lm_prob, short_seq_prob, + seed, skip_warmup, binary_head, dataset_type=dataset_type) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights) + + return (blending_train_dataset, blending_valid_dataset, + blending_test_dataset) diff --git a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..cf547ad9755815dfc0e9de449d147fe928a82948 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py @@ -0,0 +1,153 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Dataloaders.""" + +import torch +import random +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode + + +def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0): + """Build dataloader given an input dataset.""" + + if dataset is None: + return None + + # Megatron sampler + if dataloader_type == 'single': + batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), + data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) + elif dataloader_type == 'cyclic': + batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), + data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) + else: + raise Exception('{} dataloader type is not supported.'.format(dataloader_type)) + + # Torch dataloader. + return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True) + + +class MegatronPretrainingSampler: + + def __init__(self, + total_samples, + consumed_samples, + micro_batch_size, + data_parallel_rank, + data_parallel_size, + drop_last=True): + # Keep a copy of input params for later use. + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.micro_batch_times_data_parallel_size = \ + self.micro_batch_size * data_parallel_size + self.drop_last = drop_last + + # Sanity checks. + assert self.total_samples > 0, \ + 'no sample to consume: {}'.format(self.total_samples) + assert self.consumed_samples < self.total_samples, \ + 'no samples left to consume: {}, {}'.format(self.consumed_samples, + self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, \ + 'data_parallel_rank should be smaller than data size: {}, ' \ + '{}'.format(self.data_parallel_rank, data_parallel_size) + + def __len__(self): + return self.total_samples + + def get_start_end_idx(self): + start_idx = self.data_parallel_rank * self.micro_batch_size + end_idx = start_idx + self.micro_batch_size + return start_idx, end_idx + + def __iter__(self): + batch = [] + # Last batch will be dropped if drop_last is not set False + for idx in range(self.consumed_samples, self.total_samples): + batch.append(idx) + if len(batch) == self.micro_batch_times_data_parallel_size: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + batch = [] + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + + +class MegatronPretrainingRandomSampler: + + def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size): + # Keep a copy of input params for later use. + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.data_parallel_size = data_parallel_size + self.micro_batch_times_data_parallel_size = \ + self.micro_batch_size * data_parallel_size + self.last_batch_size = \ + self.total_samples % self.micro_batch_times_data_parallel_size + + # Sanity checks. + assert self.total_samples > 0, \ + 'no sample to consume: {}'.format(self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, \ + 'data_parallel_rank should be smaller than data size: {}, ' \ + '{}'.format(self.data_parallel_rank, data_parallel_size) + + def __len__(self): + return self.total_samples + + def __iter__(self): + active_total_samples = self.total_samples - self.last_batch_size + self.epoch = self.consumed_samples // active_total_samples + current_epoch_samples = self.consumed_samples % active_total_samples + assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 + + # data sharding and random sampling + bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ + * self.micro_batch_size + bucket_offset = current_epoch_samples // self.data_parallel_size + start_idx = self.data_parallel_rank * bucket_size + + g = torch.Generator() + g.manual_seed(self.epoch) + random_idx = torch.randperm(bucket_size, generator=g).tolist() + idx_range = [start_idx + x for x in random_idx[bucket_offset:]] + + batch = [] + # Last batch if not complete will be dropped. + for idx in idx_range: + batch.append(idx) + if len(batch) == self.micro_batch_size: + self.consumed_samples += self.micro_batch_times_data_parallel_size + yield batch + batch = [] diff --git a/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py b/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf4e4763fc107ca9da3f063037bf08f4efc67cdd --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py @@ -0,0 +1,592 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors, and NVIDIA. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Most of the code here has been copied from: +# https://github.com/google-research/albert/blob/master/create_pretraining_data.py +# with some modifications. + +import math +import time +import collections +from colossalai.logging import get_dist_logger +import numpy as np +from .blendable_dataset import BlendableDataset +from .indexed_dataset import make_dataset as make_indexed_dataset + +DSET_TYPE_STD = 'standard_bert' +DSET_TYPE_ICT = 'ict' + +DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD] + + +def get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples): + + # The data prefix should be in the format of: + # weight-1, data-prefix-1, weight-2, data-prefix-2, .. + assert len(data_prefix) % 2 == 0 + num_datasets = len(data_prefix) // 2 + weights = [0]*num_datasets + prefixes = [0]*num_datasets + for i in range(num_datasets): + weights[i] = float(data_prefix[2*i]) + prefixes[i] = (data_prefix[2*i+1]).strip() + # Normalize weights + weight_sum = 0.0 + for weight in weights: + weight_sum += weight + assert weight_sum > 0.0 + weights = [weight / weight_sum for weight in weights] + + # Add 0.5% (the 1.005 factor) so in case the bleding dataset does + # not uniformly distribute the number of samples, we still have + # samples left to feed to the network. + datasets_train_valid_test_num_samples = [] + for weight in weights: + datasets_train_valid_test_num_samples.append( + [int(math.ceil(val * weight * 1.005)) + for val in train_valid_test_num_samples]) + + return prefixes, weights, datasets_train_valid_test_num_samples + + +def compile_helper(): + """Compile helper function ar runtime. Make sure this + is invoked on a single process.""" + import os + import subprocess + path = os.path.abspath(os.path.dirname(__file__)) + ret = subprocess.run(['make', '-C', path]) + if ret.returncode != 0: + print("Making C++ dataset helpers module failed, exiting.") + import sys + sys.exit(1) + + +def get_a_and_b_segments(sample, np_rng): + """Divide sample into a and b segments.""" + + # Number of sentences in the sample. + n_sentences = len(sample) + # Make sure we always have two sentences. + assert n_sentences > 1, 'make sure each sample has at least two sentences.' + + # First part: + # `a_end` is how many sentences go into the `A`. + a_end = 1 + if n_sentences >= 3: + # Note that randin in numpy is exclusive. + a_end = np_rng.randint(1, n_sentences) + tokens_a = [] + for j in range(a_end): + tokens_a.extend(sample[j]) + + # Second part: + tokens_b = [] + for j in range(a_end, n_sentences): + tokens_b.extend(sample[j]) + + # Random next: + is_next_random = False + if np_rng.random() < 0.5: + is_next_random = True + tokens_a, tokens_b = tokens_b, tokens_a + + return tokens_a, tokens_b, is_next_random + + +def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): + """Truncates a pair of sequences to a maximum sequence length.""" + #print(len_a, len_b, max_num_tokens) + assert len_a > 0 + if len_a + len_b <= max_num_tokens: + return False + while len_a + len_b > max_num_tokens: + if len_a > len_b: + len_a -= 1 + tokens = tokens_a + else: + len_b -= 1 + tokens = tokens_b + if np_rng.random() < 0.5: + del tokens[0] + else: + tokens.pop() + return True + + +def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): + """Merge segments A and B, add [CLS] and [SEP] and build tokentypes.""" + + tokens = [] + tokentypes = [] + # [CLS]. + tokens.append(cls_id) + tokentypes.append(0) + # Segment A. + for token in tokens_a: + tokens.append(token) + tokentypes.append(0) + # [SEP]. + tokens.append(sep_id) + tokentypes.append(0) + # Segment B. + for token in tokens_b: + tokens.append(token) + tokentypes.append(1) + if tokens_b: + # [SEP]. + tokens.append(sep_id) + tokentypes.append(1) + + return tokens, tokentypes + + +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", + ["index", "label"]) + + +def is_start_piece(piece): + """Check if the current word piece is the starting piece (BERT).""" + # When a word has been split into + # WordPieces, the first token does not have any marker and any subsequence + # tokens are prefixed with ##. So whenever we see the ## token, we + # append it to the previous set of word indexes. + return not piece.startswith("##") + + +def create_masked_lm_predictions(tokens, + vocab_id_list, vocab_id_to_token_dict, + masked_lm_prob, + cls_id, sep_id, mask_id, + max_predictions_per_seq, + np_rng, + max_ngrams=3, + do_whole_word_mask=True, + favor_longer_ngram=False, + do_permutation=False): + """Creates the predictions for the masked LM objective. + Note: Tokens here are vocab ids and not text tokens.""" + + cand_indexes = [] + # Note(mingdachen): We create a list for recording if the piece is + # the starting piece of current token, where 1 means true, so that + # on-the-fly whole word masking is possible. + token_boundary = [0] * len(tokens) + + for (i, token) in enumerate(tokens): + if token == cls_id or token == sep_id: + token_boundary[i] = 1 + continue + # Whole Word Masking means that if we mask all of the wordpieces + # corresponding to an original word. + # + # Note that Whole Word Masking does *not* change the training code + # at all -- we still predict each WordPiece independently, softmaxed + # over the entire vocabulary. + if (do_whole_word_mask and len(cand_indexes) >= 1 and + not is_start_piece(vocab_id_to_token_dict[token])): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + if is_start_piece(vocab_id_to_token_dict[token]): + token_boundary[i] = 1 + + output_tokens = list(tokens) + + masked_lm_positions = [] + masked_lm_labels = [] + + if masked_lm_prob == 0: + return (output_tokens, masked_lm_positions, + masked_lm_labels, token_boundary) + + num_to_predict = min(max_predictions_per_seq, + max(1, int(round(len(tokens) * masked_lm_prob)))) + + # Note(mingdachen): + # By default, we set the probabilities to favor shorter ngram sequences. + ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) + pvals = 1. / np.arange(1, max_ngrams + 1) + pvals /= pvals.sum(keepdims=True) + + if favor_longer_ngram: + pvals = pvals[::-1] + + ngram_indexes = [] + for idx in range(len(cand_indexes)): + ngram_index = [] + for n in ngrams: + ngram_index.append(cand_indexes[idx:idx + n]) + ngram_indexes.append(ngram_index) + + np_rng.shuffle(ngram_indexes) + + masked_lms = [] + covered_indexes = set() + for cand_index_set in ngram_indexes: + if len(masked_lms) >= num_to_predict: + break + if not cand_index_set: + continue + # Note(mingdachen): + # Skip current piece if they are covered in lm masking or previous ngrams. + for index_set in cand_index_set[0]: + for index in index_set: + if index in covered_indexes: + continue + + n = np_rng.choice(ngrams[:len(cand_index_set)], + p=pvals[:len(cand_index_set)] / + pvals[:len(cand_index_set)].sum(keepdims=True)) + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # Note(mingdachen): + # Repeatedly looking for a candidate that does not exceed the + # maximum number of predictions by trying shorter ngrams. + while len(masked_lms) + len(index_set) > num_to_predict: + if n == 0: + break + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(masked_lms) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + + masked_token = None + # 80% of the time, replace with [MASK] + if np_rng.random() < 0.8: + masked_token = mask_id + else: + # 10% of the time, keep original + if np_rng.random() < 0.5: + masked_token = tokens[index] + # 10% of the time, replace with random word + else: + masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))] + + output_tokens[index] = masked_token + + masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) + assert len(masked_lms) <= num_to_predict + + np_rng.shuffle(ngram_indexes) + + select_indexes = set() + if do_permutation: + for cand_index_set in ngram_indexes: + if len(select_indexes) >= num_to_predict: + break + if not cand_index_set: + continue + # Note(mingdachen): + # Skip current piece if they are covered in lm masking or previous ngrams. + for index_set in cand_index_set[0]: + for index in index_set: + if index in covered_indexes or index in select_indexes: + continue + + n = np.random.choice(ngrams[:len(cand_index_set)], + p=pvals[:len(cand_index_set)] / + pvals[:len(cand_index_set)].sum(keepdims=True)) + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + + while len(select_indexes) + len(index_set) > num_to_predict: + if n == 0: + break + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(select_indexes) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes or index in select_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + select_indexes.add(index) + assert len(select_indexes) <= num_to_predict + + select_indexes = sorted(select_indexes) + permute_indexes = list(select_indexes) + np_rng.shuffle(permute_indexes) + orig_token = list(output_tokens) + + for src_i, tgt_i in zip(select_indexes, permute_indexes): + output_tokens[src_i] = orig_token[tgt_i] + masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i])) + + masked_lms = sorted(masked_lms, key=lambda x: x.index) + + for p in masked_lms: + masked_lm_positions.append(p.index) + masked_lm_labels.append(p.label) + + return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) + + +def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, + masked_labels, pad_id, max_seq_length): + """Pad sequences and convert them to numpy.""" + + # Some checks. + num_tokens = len(tokens) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0 + assert len(tokentypes) == num_tokens + assert len(masked_positions) == len(masked_labels) + + # Tokens and token types. + filler = [pad_id] * padding_length + tokens_np = np.array(tokens + filler, dtype=np.int64) + tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) + + # Padding mask. + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, + dtype=np.int64) + + # Lables and loss mask. + labels = [-1] * max_seq_length + loss_mask = [0] * max_seq_length + for i in range(len(masked_positions)): + assert masked_positions[i] < num_tokens + labels[masked_positions[i]] = masked_labels[i] + loss_mask[masked_positions[i]] = 1 + labels_np = np.array(labels, dtype=np.int64) + loss_mask_np = np.array(loss_mask, dtype=np.int64) + + return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np + + +def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, skip_warmup, + binary_head, + dataset_type='standard_bert'): + + if len(data_prefix) == 1: + return _build_train_valid_test_datasets(data_prefix[0], + data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, + skip_warmup, + binary_head, + dataset_type=dataset_type) + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + prefixes[i], data_impl, splits_string, + datasets_train_valid_test_num_samples[i], + max_seq_length, masked_lm_prob, short_seq_prob, + seed, skip_warmup, binary_head, dataset_type=dataset_type) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights) + + return (blending_train_dataset, blending_valid_dataset, + blending_test_dataset) + + +def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, skip_warmup, + binary_head, + dataset_type='standard_bert'): + logger = get_dist_logger() + + if dataset_type not in DSET_TYPES: + raise ValueError("Invalid dataset_type: ", dataset_type) + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + if dataset_type == DSET_TYPE_ICT: + args = get_args() + title_dataset = get_indexed_dataset_(args.titles_data_path, + data_impl, + skip_warmup) + + # Get start and end indices of train/valid/train into doc-idx + # Note that doc-idx is designed to be num-docs + 1 so we can + # easily iterate over it. + total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1 + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. + logger.info('\n > dataset split:') + + def print_split_stats(name, index): + start_index = indexed_dataset.doc_idx[splits[index]] + end_index = indexed_dataset.doc_idx[splits[index + 1]] + logger.info('\n {}:'.format(name) + + '\n document indices in [{}, {}) total of {} documents'.format( + splits[index], + splits[index + 1], + splits[index + 1] - splits[index]) + + '\n sentence indices in [{}, {}) total of {} sentences'.format( + start_index, + end_index, + end_index - start_index), + ranks=[0]) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + from .bert_dataset import BertDataset + dataset = None + if splits[index + 1] > splits[index]: + # Get the pointer to the original doc-idx so we can set it later. + doc_idx_ptr = indexed_dataset.get_doc_idx() + # Slice the doc-idx + start_index = splits[index] + # Add +1 so we can index into the dataset to get the upper bound. + end_index = splits[index + 1] + 1 + # New doc_idx view. + indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) + # Build the dataset accordingly. + kwargs = dict( + name=name, + data_prefix=data_prefix, + num_epochs=None, + max_num_samples=train_valid_test_num_samples[index], + max_seq_length=max_seq_length, + seed=seed, + binary_head=binary_head + ) + + if dataset_type == DSET_TYPE_ICT: + args = get_args() + dataset = ICTDataset( + block_dataset=indexed_dataset, + title_dataset=title_dataset, + query_in_block_prob=args.query_in_block_prob, + use_one_sent_docs=args.use_one_sent_docs, + **kwargs + ) + else: + dataset = BertDataset( + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + short_seq_prob=short_seq_prob, + **kwargs + ) + + # Set the original pointer so dataset remains the main dataset. + indexed_dataset.set_doc_idx(doc_idx_ptr) + # Checks. + assert indexed_dataset.doc_idx[0] == 0 + assert indexed_dataset.doc_idx.shape[0] == \ + (total_num_of_documents + 1) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): + logger = get_dist_logger() + start_time = time.time() + indexed_dataset = make_indexed_dataset(data_prefix, + data_impl, + skip_warmup) + assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] + logger.info('\n > building dataset index ...', ranks=[0]) + logger.info('\n > finished creating indexed dataset in {:4f} ' + 'seconds'.format(time.time() - start_time), ranks=[0]) + logger.info('\n > indexed dataset stats:' + + '\n number of documents: {}'.format( + indexed_dataset.doc_idx.shape[0] - 1) + + '\n number of sentences: {}'.format( + indexed_dataset.sizes.shape[0]), + ranks=[0] + ) + + return indexed_dataset + + +def get_train_valid_test_split_(splits_string, size): + """ Get dataset splits from comma or '/' separated string list.""" + + splits = [] + if splits_string.find(',') != -1: + splits = [float(s) for s in splits_string.split(',')] + elif splits_string.find('/') != -1: + splits = [float(s) for s in splits_string.split('/')] + else: + splits = [float(splits_string)] + while len(splits) < 3: + splits.append(0.) + splits = splits[:3] + splits_sum = sum(splits) + assert splits_sum > 0.0 + splits = [split / splits_sum for split in splits] + splits_index = [0] + for index, split in enumerate(splits): + splits_index.append(splits_index[index] + + int(round(split * float(size)))) + diff = splits_index[-1] - size + for index in range(1, len(splits_index)): + splits_index[index] -= diff + assert len(splits_index) == 4 + assert splits_index[-1] == size + return splits_index diff --git a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e45926a976961eb5094658ba478cb697a88c8000 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp @@ -0,0 +1,717 @@ +/* + coding=utf-8 + Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + + +/* Helper methods for fast index mapping builds */ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +using namespace std; + +const int32_t LONG_SENTENCE_LEN = 512; + + +void build_blending_indices(py::array_t& dataset_index, + py::array_t& dataset_sample_index, + const py::array_t& weights, + const int32_t num_datasets, + const int64_t size, const bool verbose) { + /* Given multiple datasets and a weighting array, build samples + such that it follows those wieghts.*/ + + if (verbose) { + std::cout << "> building indices for blendable datasets ..." << std::endl; + } + + // Get the pointer access without the checks. + auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); + auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); + auto weights_ptr = weights.unchecked<1>(); + + // Initialize buffer for number of samples used for each dataset. + int64_t current_samples[num_datasets]; + for(int64_t i = 0; i < num_datasets; ++i) { + current_samples[i] = 0; + } + + // For each sample: + for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { + + // Determine where the max error in sampling is happening. + auto sample_idx_double = std::max(static_cast(sample_idx), 1.0); + int64_t max_error_index = 0; + double max_error = weights_ptr[0] * sample_idx_double - + static_cast(current_samples[0]); + for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { + double error = weights_ptr[dataset_idx] * sample_idx_double - + static_cast(current_samples[dataset_idx]); + if (error > max_error) { + max_error = error; + max_error_index = dataset_idx; + } + } + + // Populate the indices. + dataset_index_ptr[sample_idx] = static_cast(max_error_index); + dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index]; + + // Update the total samples. + current_samples[max_error_index] += 1; + + } + + // print info + if (verbose) { + std::cout << " > sample ratios:" << std::endl; + for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { + auto ratio = static_cast(current_samples[dataset_idx]) / + static_cast(size); + std::cout << " dataset " << dataset_idx << ", input: " << + weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; + } + } + +} + + +py::array build_sample_idx(const py::array_t& sizes_, + const py::array_t& doc_idx_, + const int32_t seq_length, + const int32_t num_epochs, + const int64_t tokens_per_epoch) { + /* Sample index (sample_idx) is used for gpt2 like dataset for which + the documents are flattened and the samples are built based on this + 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] + where [..., 0] contains the index into `doc_idx` and [..., 1] is the + starting offset in that document.*/ + + // Consistency checks. + assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); + + // Remove bound checks. + auto sizes = sizes_.unchecked<1>(); + auto doc_idx = doc_idx_.unchecked<1>(); + + // Mapping and it's length (1D). + int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; + int32_t* sample_idx = new int32_t[2*(num_samples+1)]; + + cout << " using:" << endl << std::flush; + cout << " number of documents: " << + doc_idx_.shape(0) / num_epochs << endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " sequence length: " << seq_length << + endl << std::flush; + cout << " total number of samples: " << num_samples << + endl << std::flush; + + // Index into sample_idx. + int64_t sample_index = 0; + // Index into doc_idx. + int64_t doc_idx_index = 0; + // Begining offset for each document. + int32_t doc_offset = 0; + // Start with first document and no offset. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + + while (sample_index <= num_samples) { + // Start with a fresh sequence. + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { + // Get the document length. + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the begining of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + } + + // Method to deallocate memory. + py::capsule free_when_done(sample_idx, [](void *mem_) { + int32_t *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(int32_t); + return py::array(std::vector{num_samples+1, 2}, // shape + {2*byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references + +} + + +inline int32_t get_target_sample_len(const int32_t short_seq_ratio, + const int32_t max_length, + std::mt19937& rand32_gen) { + /* Training sample length. */ + if (short_seq_ratio == 0) { + return max_length; + } + const auto random_number = rand32_gen(); + if ((random_number % short_seq_ratio) == 0) { + return 2 + random_number % (max_length - 1); + } + return max_length; +} + + +template +py::array build_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const double short_seq_prob, + const int32_t seed, + const bool verbose, + const int32_t min_num_sent) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(short_seq_prob >= 0.0); + assert(short_seq_prob <= 1.0); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + + // For efficiency, convert probability to ratio. Note: rand() generates int. + int32_t short_seq_ratio = 0; + if (short_seq_prob > 0) { + short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); + } + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " short sequence probability: " << short_seq_prob << + endl << std::flush; + cout << " short sequence ration (1/prob): " << short_seq_ratio << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and it's length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the seed so both iterations produce the same results. + std::mt19937 rand32_gen(seed); + + // Set the flag on second iteration. + second = (iteration == 1); + + // Counters: + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + + // Current map index. + uint64_t map_index = 0; + + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent > 1) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + + // If we have more than two sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + auto target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and if not only one sentence is left in the document. + // and if we have at least two sentneces. + // and if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent > 1) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Check for overflow. + if ((3 * map_index + 2) > + std::numeric_limits::max()) { + cout << "number of samples exceeded maximum " + << "allowed by type int64: " + << std::numeric_limits::max() + << endl; + throw std::overflow_error("Number of samples"); + } + + // Populate the map. + if (second) { + const auto map_index_0 = 3 * map_index; + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(target_seq_len); + } + + // Update indices / counters. + ++map_index; + prev_start_index = sent_index + 1; + target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + seq_len = 0; + num_sent = 0; + } + + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[3*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 3 * i; + const auto j0 = 3 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 3}, // shape + {3*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + + +py::array build_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const double short_seq_prob, + const int seed, + const bool verbose, + const int32_t min_num_sent) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } +} + +template +py::array build_blocks_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const int32_t seed, + const bool verbose, + const bool use_one_sent_blocks) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + auto titles_sizes = titles_sizes_.unchecked<1>(); + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and its length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Acceptable number of sentences per block. + int min_num_sent = 2; + if (use_one_sent_blocks) { + min_num_sent = 1; + } + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the flag on second iteration. + second = (iteration == 1); + + // Current map index. + uint64_t map_index = 0; + + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + const auto target_seq_len = max_seq_length - titles_sizes[doc]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent >= min_num_sent) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + // If we have enough sentences and no long sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and there are an acceptable number of sentences left + // and if we have at least the minimum number of sentences. + // or if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent >= min_num_sent) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Populate the map. + if (second) { + const auto map_index_0 = 4 * map_index; + // Each sample has 4 items: the starting sentence index, ending sentence index, + // the index of the document from which the block comes (used for fetching titles) + // and the unique id of the block (used for creating block indexes) + + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(doc); + maps[map_index_0 + 3] = static_cast(block_id); + } + + // Update indices / counters. + ++map_index; + ++block_id; + prev_start_index = sent_index + 1; + seq_len = 0; + num_sent = 0; + } + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[4*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 4 * i; + const auto j0 = 4 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + swap(maps[i0 + 3], maps[j0 + 3]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 4}, // shape + {4*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + +py::array build_blocks_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const int seed, + const bool verbose, + const bool use_one_sent_blocks) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } +} + +PYBIND11_MODULE(helpers, m) { + m.def("build_mapping", &build_mapping); + m.def("build_blocks_mapping", &build_blocks_mapping); + m.def("build_sample_idx", &build_sample_idx); + m.def("build_blending_indices", &build_blending_indices); +} diff --git a/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6dac35ff9d413898146df5c1cc8553719e142105 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py @@ -0,0 +1,156 @@ +import itertools +import random + +import numpy as np +from torch.utils.data import Dataset + +from megatron import get_tokenizer +from megatron import get_args +from megatron.data.dataset_utils import get_indexed_dataset_ +from megatron.data.realm_dataset_utils import get_block_samples_mapping + +def make_attention_mask(source_block, target_block): + """ + Returns a 2-dimensional (2-D) attention mask + :param source_block: 1-D array + :param target_block: 1-D array + """ + mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1) + mask = mask.astype(np.int64) + # (source_length, target_length) + return mask + +def get_ict_dataset(use_titles=True, query_in_block_prob=1): + """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block()) + rather than for training, since it is only built with a single epoch sample mapping. + """ + args = get_args() + block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True) + titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True) + + kwargs = dict( + name='full', + block_dataset=block_dataset, + title_dataset=titles_dataset, + data_prefix=args.data_path, + num_epochs=1, + max_num_samples=None, + max_seq_length=args.seq_length, + seed=1, + query_in_block_prob=query_in_block_prob, + use_titles=use_titles, + use_one_sent_docs=args.use_one_sent_docs + ) + dataset = ICTDataset(**kwargs) + return dataset + + +class ICTDataset(Dataset): + """Dataset containing sentences and their blocks for an inverse cloze task.""" + def __init__(self, name, block_dataset, title_dataset, data_prefix, + num_epochs, max_num_samples, max_seq_length, query_in_block_prob, + seed, use_titles=True, use_one_sent_docs=False, binary_head=False): + self.name = name + self.seed = seed + self.max_seq_length = max_seq_length + self.query_in_block_prob = query_in_block_prob + self.block_dataset = block_dataset + self.title_dataset = title_dataset + self.rng = random.Random(self.seed) + self.use_titles = use_titles + self.use_one_sent_docs = use_one_sent_docs + + self.samples_mapping = get_block_samples_mapping( + block_dataset, title_dataset, data_prefix, num_epochs, + max_num_samples, max_seq_length, seed, name, use_one_sent_docs) + self.tokenizer = get_tokenizer() + self.vocab_id_list = list(self.tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_list = self.tokenizer.inv_vocab + self.cls_id = self.tokenizer.cls + self.sep_id = self.tokenizer.sep + self.mask_id = self.tokenizer.mask + self.pad_id = self.tokenizer.pad + + def __len__(self): + return len(self.samples_mapping) + + def __getitem__(self, idx): + """Get an ICT example of a pseudo-query and the block of text from which it was extracted""" + sample_data = self.samples_mapping[idx] + start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple() + + if self.use_titles: + title = self.title_dataset[int(doc_idx)] + title_pad_offset = 3 + len(title) + else: + title = None + title_pad_offset = 2 + block = [self.block_dataset[i] for i in range(start_idx, end_idx)] + assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1 + + # randint() is inclusive for Python rng + rand_sent_idx = self.rng.randint(0, len(block) - 1) + + # keep the query in the context query_in_block_prob fraction of the time. + if self.rng.random() < self.query_in_block_prob: + query = block[rand_sent_idx].copy() + else: + query = block.pop(rand_sent_idx) + + # still need to truncate because blocks are concluded when + # the sentence lengths have exceeded max_seq_length. + query = query[:self.max_seq_length - 2] + block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset] + + query_tokens, query_pad_mask = self.concat_and_pad_tokens(query) + context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title) + + query_mask = make_attention_mask(query_tokens, query_tokens) + context_mask = make_attention_mask(context_tokens, context_tokens) + + block_data = sample_data.as_array() + + sample = { + 'query_tokens': query_tokens, + 'query_mask': query_mask, + 'query_pad_mask': query_pad_mask, + 'context_tokens': context_tokens, + 'context_mask': context_mask, + 'context_pad_mask': context_pad_mask, + 'block_data': block_data, + } + + return sample + + def get_block(self, start_idx, end_idx, doc_idx): + """Get the IDs for an evidence block plus the title of the corresponding document""" + block = [self.block_dataset[i] for i in range(start_idx, end_idx)] + title = self.title_dataset[int(doc_idx)] + + block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))] + block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) + + return block_tokens, block_pad_mask + + def get_null_block(self): + """Get empty block and title - used in REALM pretraining""" + block, title = [], [] + block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) + + return block_tokens, block_pad_mask + + def concat_and_pad_tokens(self, tokens, title=None): + """Concat with special tokens and pad sequence to self.max_seq_length""" + tokens = list(tokens) + if title is None: + tokens = [self.cls_id] + tokens + [self.sep_id] + else: + title = list(title) + tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id] + assert len(tokens) <= self.max_seq_length + + num_pad = self.max_seq_length - len(tokens) + pad_mask = [1] * len(tokens) + [0] * num_pad + tokens += [self.pad_id] * num_pad + + return np.array(tokens), np.array(pad_mask) diff --git a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b4febcd822e1d8640a69e2f31287caad0bba39f7 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py @@ -0,0 +1,569 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# copied from fairseq/fairseq/data/indexed_dataset.py +# Removed IndexedRawTextDataset since it relied on Fairseq dictionary +# other slight modifications to remove fairseq dependencies +# Added document index to index file and made it accessible. +# An empty sentence no longer separates documents. + +from functools import lru_cache +import os +import shutil +import struct +from itertools import accumulate + +import numpy as np +import torch + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def get_available_dataset_impl(): + return ['lazy', 'cached', 'mmap'] + + +def infer_dataset_impl(path): + if IndexedDataset.exists(path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return 'cached' + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return 'mmap' + else: + return None + else: + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == 'mmap': + return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, skip_warmup=False): + if not IndexedDataset.exists(path): + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + if impl == 'infer': + impl = infer_dataset_impl(path) + if impl == 'lazy' and IndexedDataset.exists(path): + return IndexedDataset(path) + elif impl == 'cached' and IndexedDataset.exists(path): + return IndexedCachedDataset(path) + elif impl == 'mmap' and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup) + print(f"Unknown dataset implementation: {impl}") + return None + + +def dataset_exists(path, impl): + if impl == 'mmap': + return MMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float, + 7: np.double, + 8: np.uint16 +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + '.idx' + + +def data_file_path(prefix_path): + return prefix_path + '.bin' + + +def create_doc_idx(sizes): + doc_idx = [0] + for i, s in enumerate(sizes): + if s == 0: + doc_idx.append(i + 1) + return doc_idx + + +class IndexedDataset(torch.utils.data.Dataset): + """Loader for IndexedDataset""" + _HDR_MAGIC = b'TNTIDX\x00\x00' + + def __init__(self, path): + super().__init__() + self.path = path + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + 'Index file doesn\'t match expected format. ' + 'Make sure that --dataset-impl is configured properly.' + ) + version = f.read(8) + assert struct.unpack('= self._len: + raise IndexError('index out of range') + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return ( + os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + ) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + + def __init__(self, path): + super().__init__(path) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx: ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx: ptx + a.size]) + return a + elif isinstance(idx, slice): + # Hack just to make this work, can optimizer later if necessary + sents = [] + for i in range(*idx.indices(len(self))): + sents.append(self[i]) + return sents + + +class IndexedDatasetBuilder(object): + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float: 4, + np.double: 8 + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, 'wb') + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + self.doc_idx = [0] + + def add_item(self, tensor): + bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype)) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in tensor.size(): + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) + + def end_document(self): + self.doc_idx.append(len(self.sizes)) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + begin = self.data_offsets[-1] + for offset in index.data_offsets[1:]: + self.data_offsets.append(begin + offset) + self.sizes.extend(index.sizes) + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + with open(data_file_path(another_file), 'rb') as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, 'wb') + index.write(b'TNTIDX\x00\x00') + index.write(struct.pack(' len(ds.doc_idx) - 1: + args.count = len(ds.doc_idx) - 1 + + for i in range(args.count): + start = ds.doc_idx[i] + end = ds.doc_idx[i + 1] + ids = ds[start:end] + print(f"Document {i}:") + print("--------------") + for s in ids: + assert len(s) > 0 + l = s.data.tolist() + text = tokenizer.detokenize(l) + print(text) + print("---") + + +def test_indexed_dataset_get(args): + ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) + tokenizer = build_tokenizer(args) + size = ds.sizes[0] + print(f"size: {size}") + full = ds.get(0) + print(full) + # print(tokenizer.detokenize(full.data.tolist())) + print("---") + end = ds.get(0, offset=size - 10) + print(end) + # print(tokenizer.detokenize(end.data.tolist())) + + start = ds.get(0, length=10) + print(start) + # print(tokenizer.detokenize(start.data.tolist())) + + part = ds.get(0, offset=2, length=8) + print(part) + # print(tokenizer.detokenize(part.data.tolist())) + +# def test_albert_dataset(args): +# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True) +# # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl) +# # ds = AlbertDataset(idataset, tokenizer) +# ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl, +# args.epochs, args.max_num_samples, +# args.masked_lm_prob, args.seq_length, +# args.short_seq_prob, args.seed) +# truncated = 0 +# total = 0 +# for i, s in enumerate(ds): +# ids = s['text'] +# tokens = ds.tokenizer.convert_ids_to_tokens(ids) +# print(tokens) +# if i >= args.count-1: +# exit() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--data', type=str, help='prefix to data files') + parser.add_argument('--dataset-impl', type=str, default='infer', + choices=['lazy', 'cached', 'mmap', 'infer']) + parser.add_argument('--count', type=int, default=10, + help='Number of samples/documents to print') + + group = parser.add_argument_group(title='tokenizer') + group.add_argument('--tokenizer-type', type=str, required=True, + choices=['BertWordPieceLowerCase', + 'GPT2BPETokenizer'], + help='What type of tokenizer to use.') + group.add_argument('--vocab-file', type=str, default=None, + help='Path to the vocab file') + group.add_argument('--merge-file', type=str, default=None, + help='Path to the BPE merge file (if necessary).') + + parser.add_argument('--epochs', type=int, default=5, + help='Number of epochs to plan for') + parser.add_argument('--max-num-samples', type=int, default=None, + help='Maximum number of samples to plan for') + parser.add_argument('--masked-lm-prob', type=float, default=0.15, + help='probability of masking tokens') + parser.add_argument('--seq-length', type=int, default=512, + help='maximum sequence length') + parser.add_argument('--short-seq-prob', type=float, default=0.1, + help='probability of creating a short sequence') + parser.add_argument('--seed', type=int, default=1234, + help='random seed') + args = parser.parse_args() + args.rank = 0 + args.make_vocab_size_divisible_by = 128 + args.tensor_model_parallel_size = 1 + + if args.dataset_impl == "infer": + args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data) + +# test_albert_dataset(args) + test_indexed_dataset_get(args) + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/sequence_parallel/data/datasets/test/test_preprocess_data.sh b/examples/tutorial/sequence_parallel/data/datasets/test/test_preprocess_data.sh new file mode 100755 index 0000000000000000000000000000000000000000..d121c85958ff35e37431befdceabb831c8cd2705 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/test/test_preprocess_data.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +IMPL=cached +python ../preprocess_data.py \ + --input test_samples.json \ + --vocab vocab.txt \ + --dataset-impl ${IMPL} \ + --output-prefix test_samples_${IMPL} \ + --workers 1 \ + --log-interval 2 diff --git a/examples/tutorial/sequence_parallel/data/dummy_dataloader.py b/examples/tutorial/sequence_parallel/data/dummy_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..faa90175cc60fa06283d1ae87ae992569226005c --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/dummy_dataloader.py @@ -0,0 +1,39 @@ +import torch + + +class DummyDataloader(): + + def __init__(self, batch_size, vocab_size, seq_length): + self.batch_size = batch_size + self.vocab_size = vocab_size + self.seq_length = seq_length + self.step = 0 + + def generate(self): + tokens = torch.randint(low=0, high=self.vocab_size, size=( + self.batch_size, + self.seq_length, + )) + types = torch.randint(low=0, high=3, size=( + self.batch_size, + self.seq_length, + )) + sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,)) + loss_mask = torch.randint(low=0, high=2, size=( + self.batch_size, + self.seq_length, + )) + lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length)) + padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length)) + return dict(text=tokens, + types=types, + is_random=sentence_order, + loss_mask=loss_mask, + labels=lm_labels, + padding_mask=padding_mask) + + def __iter__(self): + return self + + def __next__(self): + return self.generate() \ No newline at end of file diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/__init__.py b/examples/tutorial/sequence_parallel/data/tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df27f14247eba7b887f73ca1022f08b1597026dd --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/tokenizer/__init__.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .tokenizer import build_tokenizer + + +_TOKENIZER = None +_PADDED_VOCAB_SIZE = -1 + + +def initialize_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): + tokenizer, padded_vocab_size = build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids) + global _TOKENIZER, _PADDED_VOCAB_SIZE + _TOKENIZER = tokenizer + _PADDED_VOCAB_SIZE = padded_vocab_size + + +def get_tokenizer(): + global _TOKENIZER + return _TOKENIZER + + +def get_padded_vocab_size(): + global _PADDED_VOCAB_SIZE + return _PADDED_VOCAB_SIZE diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/bert_tokenization.py b/examples/tutorial/sequence_parallel/data/tokenizer/bert_tokenization.py new file mode 100644 index 0000000000000000000000000000000000000000..1be494793909e878132a659104a8ecc597f4552e --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/tokenizer/bert_tokenization.py @@ -0,0 +1,431 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re +import unicodedata +import six + + +def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): + """Checks whether the casing config is consistent with the checkpoint name.""" + + # The casing has to be passed in by the user and there is no explicit check + # as to whether it matches the checkpoint. The casing information probably + # should have been stored in the bert_config.json file, but it's not, so + # we have to heuristically detect it to validate. + + if not init_checkpoint: + return + + m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) + if m is None: + return + + model_name = m.group(1) + + lower_models = [ + "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", + "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" + ] + + cased_models = [ + "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", + "multi_cased_L-12_H-768_A-12" + ] + + is_bad_config = False + if model_name in lower_models and not do_lower_case: + is_bad_config = True + actual_flag = "False" + case_name = "lowercased" + opposite_flag = "True" + + if model_name in cased_models and do_lower_case: + is_bad_config = True + actual_flag = "True" + case_name = "cased" + opposite_flag = "False" + + if is_bad_config: + raise ValueError( + "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " + "However, `%s` seems to be a %s model, so you " + "should pass in `--do_lower_case=%s` so that the fine-tuning matches " + "how the model was pre-training. If this error is wrong, please " + "just comment out this check." % (actual_flag, init_checkpoint, + model_name, case_name, opposite_flag)) + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + elif isinstance(text, unicode): + return text + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def printable_text(text): + """Returns text encoded in a way suitable for print or `tf.logging`.""" + + # These functions want `str` for both Python2 and Python3, but in one case + # it's a Unicode string and in the other it's a byte string. + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text + elif isinstance(text, unicode): + return text.encode("utf-8") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r") as reader: + while True: + token = convert_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def convert_by_vocab(vocab, items): + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + output.append(vocab[item]) + return output + + +def convert_tokens_to_ids(vocab, tokens): + return convert_by_vocab(vocab, tokens) + + +def convert_ids_to_tokens(inv_vocab, ids): + return convert_by_vocab(inv_vocab, ids) + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class FullTokenizer(object): + """Runs end-to-end tokenization.""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def convert_tokens_to_ids(self, tokens): + return convert_by_vocab(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return convert_by_vocab(self.inv_vocab, ids) + + @staticmethod + def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): + """ Converts a sequence of tokens (string) in a single string. """ + + def clean_up_tokenization(out_string): + """ Clean up a list of simple English tokenization artifacts + like spaces before punctuations and abbreviated forms. + """ + out_string = ( + out_string.replace(" .", ".") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ,", ",") + .replace(" ' ", "'") + .replace(" n't", "n't") + .replace(" 'm", "'m") + .replace(" 's", "'s") + .replace(" 've", "'ve") + .replace(" 're", "'re") + ) + return out_string + + text = ' '.join(tokens).replace(' ##', '').strip() + if clean_up_tokenization_spaces: + clean_text = clean_up_tokenization(text) + return clean_text + else: + return text + + def vocab_size(self): + return len(self.vocab) + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = convert_to_unicode(text) + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer. + + Returns: + A list of wordpiece tokens. + """ + + text = convert_to_unicode(text) + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically control characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat in ("Cc", "Cf"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3c923e8e76517b8a5244608d6c18d080082952 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py @@ -0,0 +1,256 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron tokenizers.""" + +from abc import ABC +from abc import abstractmethod +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode + +from .bert_tokenization import FullTokenizer as FullBertTokenizer + + +def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): + """Initialize tokenizer.""" + if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: + print('> building {} tokenizer ...'.format(tokenizer_type), + flush=True) + + # Select and instantiate the tokenizer. + if tokenizer_type == 'BertWordPieceLowerCase': + tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, + lower_case=True, + vocab_extra_ids=vocab_extra_ids) + elif tokenizer_type == 'BertWordPieceCase': + tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, + lower_case=False, + vocab_extra_ids=vocab_extra_ids) + else: + raise NotImplementedError('{} tokenizer is not ' + 'implemented.'.format(tokenizer_type)) + + # Add vocab size. + padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size) + + return tokenizer, padded_vocab_size + + +def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128): + """Pad vocab size so it is divisible by model parallel size and + still having GPU friendly size.""" + + after = orig_vocab_size + + if gpc.is_initialized(ParallelMode.TENSOR): + multiple = make_vocab_size_divisible_by * gpc.get_world_size(ParallelMode.TENSOR) + else: + multiple = make_vocab_size_divisible_by + while (after % multiple) != 0: + after += 1 + if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: + print(' > padded vocab (size: {}) with {} dummy tokens ' + '(new size: {})'.format( + orig_vocab_size, after - orig_vocab_size, after), flush=True) + return after + + +class AbstractTokenizer(ABC): + """Abstract class for tokenizer.""" + + def __init__(self, name): + self.name = name + super().__init__() + + @property + @abstractmethod + def vocab_size(self): + pass + + @property + @abstractmethod + def vocab(self): + """Dictionary from vocab text token to id token.""" + pass + + @property + @abstractmethod + def inv_vocab(self): + """Dictionary from vocab id token to text token.""" + pass + + @abstractmethod + def tokenize(self, text): + pass + + def detokenize(self, token_ids): + raise NotImplementedError('detokenizer is not implemented for {} ' + 'tokenizer'.format(self.name)) + + @property + def cls(self): + raise NotImplementedError('CLS is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def sep(self): + raise NotImplementedError('SEP is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def pad(self): + raise NotImplementedError('PAD is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def eod(self): + raise NotImplementedError('EOD is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def mask(self): + raise NotImplementedError('MASK is not provided for {} ' + 'tokenizer'.format(self.name)) + + +class _BertWordPieceTokenizer(AbstractTokenizer): + """Original BERT wordpiece tokenizer.""" + + def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): + if lower_case: + name = 'BERT Lower Case' + else: + name = 'BERT Upper Case' + super().__init__(name) + self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case) + self.cls_id = self.tokenizer.vocab['[CLS]'] + self.sep_id = self.tokenizer.vocab['[SEP]'] + self.pad_id = self.tokenizer.vocab['[PAD]'] + self.mask_id = self.tokenizer.vocab['[MASK]'] + self._additional_special_tokens = [] + + # (dsachan) Add BOS and EOS tokens + SPECIAL_TOKENS = {'eos_token': '[EOS]', + 'bos_token': '[BOS]'} + self._bos_token = '[BOS]' + self.add_token(self._bos_token) + self._bos_token_id = self.vocab.get(self._bos_token) + + self._eos_token = '[EOS]' + self.add_token(self._eos_token) + self._eos_token_id = self.vocab.get(self._eos_token) + + # (dsachan) Add additional special tokens + # These can be used as sentinel tokens in T5 model inputs + additional_special_tokens = [] + additional_special_tokens.extend( + ["".format(i) for i in range(vocab_extra_ids)]) + self.add_additional_special_tokens(additional_special_tokens) + + def add_token(self, token): + if token not in self.vocab: + self.inv_vocab[self.vocab_size] = token + # self.vocab_size comes from len(vocab) + # and it will increase as we add elements + self.vocab[token] = self.vocab_size + + def add_additional_special_tokens(self, tokens_list): + setattr(self, "additional_special_tokens", tokens_list) + for value in tokens_list: + self.add_token(value) + + @property + def vocab_size(self): + return self.tokenizer.vocab_size() + + @property + def vocab(self): + return self.tokenizer.vocab + + @property + def inv_vocab(self): + return self.tokenizer.inv_vocab + + def tokenize(self, text): + text_tokens = self.tokenizer.tokenize(text) + return self.tokenizer.convert_tokens_to_ids(text_tokens) + + def decode(self, ids): + tokens = self.tokenizer.convert_ids_to_tokens(ids) + return self.tokenizer.convert_tokens_to_string(tokens) + + def decode_token_ids(self, token_ids): + tokens = self.tokenizer.convert_ids_to_tokens(token_ids) + exclude_list = ['[PAD]', '[CLS]'] + non_pads = [t for t in tokens if t not in exclude_list] + + result = "" + for s in non_pads: + if s.startswith("##"): + result += s[2:] + else: + result += " " + s + + return result + + @property + def cls(self): + return self.cls_id + + @property + def sep(self): + return self.sep_id + + @property + def pad(self): + return self.pad_id + + @property + def mask(self): + return self.mask_id + + @property + def bos_token(self): + """ Beginning of sentence token id """ + return self._bos_token + + @property + def eos_token(self): + """ End of sentence token id """ + return self._eos_token + + @property + def additional_special_tokens(self): + """ All the additional special tokens you may want to use (list of strings).""" + return self._additional_special_tokens + + @property + def bos_token_id(self): + """ Id of the beginning of sentence token in the vocabulary.""" + return self._bos_token_id + + @property + def eos_token_id(self): + """ Id of the end of sentence token in the vocabulary.""" + return self._eos_token_id + + @property + def additional_special_tokens_ids(self): + """ Ids of all the additional special tokens in the vocabulary (list of integers).""" + return [self.vocab.get(token) for token in self._additional_special_tokens] + + @additional_special_tokens.setter + def additional_special_tokens(self, value): + self._additional_special_tokens = value diff --git a/examples/tutorial/sequence_parallel/loss_func/__init__.py b/examples/tutorial/sequence_parallel/loss_func/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e87a778cf5d5959f68c134957f101787ca87cffe --- /dev/null +++ b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.logging import get_dist_logger +import torch.nn.functional as F +import torch.distributed as dist +from .cross_entropy import vocab_cross_entropy + + +class BertLoss(nn.Module): + + def forward(self, + lm_loss, + sop_logits, + loss_mask, + sentence_order): + lm_loss_ = lm_loss.float() + loss_mask = loss_mask.float() + loss_mask_sum = loss_mask.sum() + lm_loss = torch.sum( + lm_loss_.view(-1) * loss_mask.reshape(-1)) + + lm_loss /= loss_mask_sum + + torch.distributed.all_reduce( + lm_loss, + group=gpc.get_group(ParallelMode.SEQUENCE) + ) + + if sop_logits is not None: + sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), + sentence_order.view(-1), + ignore_index=-1) + sop_loss = sop_loss.float() + loss = lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE) + else: + sop_loss = None + loss = lm_loss + + return loss diff --git a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..54553c29a61f91b34d9b192ab734ff201a92892e --- /dev/null +++ b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py @@ -0,0 +1,75 @@ +from colossalai.context.parallel_mode import ParallelMode +import torch +from torch.cuda.amp import custom_bwd, custom_fwd + + +class _VocabCrossEntropy(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, vocab_parallel_logits, target): + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + + # Subtract the maximum value. + vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = target < 0 + masked_target = target.clone() + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], + device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = vocab_parallel_logits + torch.exp(vocab_parallel_logits, out=exp_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits) - predicted_logits + + # Store softmax, target-mask and masked-target for backward pass. + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + return loss + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + + # All the inputs have softmax as their gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], + device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= ( + 1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None + + +def vocab_cross_entropy(vocab_logits, target): + """helper function for the cross entropy.""" + + return _VocabCrossEntropy.apply(vocab_logits, target) diff --git a/examples/tutorial/sequence_parallel/loss_func/utils.py b/examples/tutorial/sequence_parallel/loss_func/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a3d92f294326312a6f3cb9be04ed38124b104485 --- /dev/null +++ b/examples/tutorial/sequence_parallel/loss_func/utils.py @@ -0,0 +1,55 @@ + +import torch + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, '{} is not divisible by {}'.format( + numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim(tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class VocabUtility: + """Split the vocabulary into `world_size` chunks amd return the + first and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last)""" + + @staticmethod + def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, + rank, world_size): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + @staticmethod + def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return VocabUtility.vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size, rank, world_size) diff --git a/examples/tutorial/sequence_parallel/lr_scheduler/__init__.py b/examples/tutorial/sequence_parallel/lr_scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b8b615bce63cd1750db9cb063261d51a52e8b30 --- /dev/null +++ b/examples/tutorial/sequence_parallel/lr_scheduler/__init__.py @@ -0,0 +1 @@ +from .annealing_lr import AnnealingLR diff --git a/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py b/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..8d95679ff76dcdab3020c02b30e9e8616d086c73 --- /dev/null +++ b/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Learning rate decay functions.""" + +import math + + +class AnnealingLR(object): + """Anneals the learning rate.""" + + def __init__(self, + optimizer, + max_lr, + min_lr, + warmup_steps, + decay_steps, + decay_style, + use_checkpoint_lr_scheduler=True, + override_lr_scheduler=False): + + # Class values. + self.optimizer = optimizer + + self.max_lr = float(max_lr) + self.min_lr = min_lr + assert self.min_lr >= 0.0 + assert self.max_lr >= self.min_lr + + self.warmup_steps = warmup_steps + self.num_steps = 0 + self.decay_steps = decay_steps + assert self.decay_steps > 0 + assert self.warmup_steps < self.decay_steps + + self.decay_style = decay_style + + self.override_lr_scheduler = override_lr_scheduler + self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler + if self.override_lr_scheduler: + assert not self.use_checkpoint_lr_scheduler, 'both override and '\ + 'use-checkpoint are set.' + + # Set the learning rate + self.step(0) + + def get_lr(self): + """Learning rate decay functions from: + https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" + + # Use linear warmup for the initial part. + if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps: + return self.max_lr * float(self.num_steps) / \ + float(self.warmup_steps) + + # If the learning rate is constant, just return the initial value. + if self.decay_style == 'constant': + return self.max_lr + + # For any steps larger than `self.decay_steps`, use `self.min_lr`. + if self.num_steps > self.decay_steps: + return self.min_lr + + # If we are done with the warmup period, use the decay style. + num_steps_ = self.num_steps - self.warmup_steps + decay_steps_ = self.decay_steps - self.warmup_steps + decay_ratio = float(num_steps_) / float(decay_steps_) + assert decay_ratio >= 0.0 + assert decay_ratio <= 1.0 + delta_lr = self.max_lr - self.min_lr + + if self.decay_style == 'linear': + coeff = (1.0 - decay_ratio) + elif self.decay_style == 'cosine': + coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) + else: + raise Exception('{} decay style is not supported.'.format( + self.decay_style)) + + return self.min_lr + coeff * delta_lr + + def step(self, increment=1): + """Set lr for all parameters groups.""" + self.num_steps += increment + new_lr = self.get_lr() + for group in self.optimizer.param_groups: + group['lr'] = new_lr + + def state_dict(self): + state_dict = { + 'max_lr': self.max_lr, + 'warmup_steps': self.warmup_steps, + 'num_steps': self.num_steps, + 'decay_style': self.decay_style, + 'decay_steps': self.decay_steps, + 'min_lr': self.min_lr + } + return state_dict + + def _check_and_set(self, cls_value, sd_value, name): + """Auxiliary function for checking the values in the checkpoint and + setting them.""" + if self.override_lr_scheduler: + return cls_value + + if not self.use_checkpoint_lr_scheduler: + assert cls_value == sd_value, \ + f'AnnealingLR: class input value {cls_value} and checkpoint' \ + f'value {sd_value} for {name} do not match' + return sd_value + + def load_state_dict(self, sd): + + if 'start_lr' in sd: + max_lr_ = sd['start_lr'] + else: + max_lr_ = sd['max_lr'] + self.max_lr = self._check_and_set(self.max_lr, max_lr_, + 'learning rate') + + self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], + 'minimum learning rate') + + if 'warmup_iter' in sd: + warmup_steps_ = sd['warmup_iter'] + else: + warmup_steps_ = sd['warmup_steps'] + self.warmup_steps = self._check_and_set(self.warmup_steps, + warmup_steps_, + 'warmup iterations') + + if 'end_iter' in sd: + decay_steps_ = sd['end_iter'] + else: + decay_steps_ = sd['decay_steps'] + self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, + 'total number of iterations') + self.decay_style = self._check_and_set(self.decay_style, + sd['decay_style'], + 'decay style') + + if 'num_iters' in sd: + num_steps = sd['num_iters'] + else: + num_steps = sd['num_steps'] + self.step(increment=num_steps) diff --git a/examples/tutorial/sequence_parallel/model/__init__.py b/examples/tutorial/sequence_parallel/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/examples/tutorial/sequence_parallel/model/bert.py b/examples/tutorial/sequence_parallel/model/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..049579c5a639c2bab29348fe0290b75497dbdc31 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/bert.py @@ -0,0 +1,282 @@ +from colossalai.context.parallel_mode import ParallelMode +import torch +import torch.nn as nn +import inspect +from .layers import Embedding, BertLayer, BertDualHead, PreProcessor, VocabEmbedding +from .layers.init_method import init_normal, output_init_normal +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.kernel import LayerNorm +from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper +from colossalai.logging import get_dist_logger +from colossalai.pipeline.utils import partition_uniform + + +class BertForPretrain(nn.Module): + + def __init__(self, + vocab_size, + hidden_size, + max_sequence_length, + num_attention_heads, + num_layers, + add_binary_head, + is_naive_fp16, + num_tokentypes=2, + dropout_prob=0.1, + mlp_ratio=4, + init_std=0.02, + convert_fp16_to_fp32_in_softmax=False, + ): + super().__init__() + self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE) + assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size' + self.sub_seq_length = max_sequence_length // self.seq_parallel_size + self.init_std = init_std + self.num_layers = num_layers + + if not add_binary_head: + num_tokentypes = 0 + + self.preprocessor = PreProcessor(self.sub_seq_length) + self.embedding = Embedding(hidden_size=hidden_size, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + embedding_dropout_prob=dropout_prob, + num_tokentypes=num_tokentypes) + self.bert_layers = nn.ModuleList() + + for i in range(num_layers): + bert_layer = BertLayer(layer_number=i+1, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=dropout_prob, + mlp_ratio=mlp_ratio, + hidden_dropout=dropout_prob, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + is_naive_fp16=is_naive_fp16 + ) + self.bert_layers.append(bert_layer) + + self.layer_norm = LayerNorm(hidden_size) + self.head = BertDualHead(hidden_size, self.embedding.word_embedding_weight.size(0), + add_binary_head=add_binary_head) + self.reset_parameters() + + def _init_normal(self, tensor): + init_normal(tensor, sigma=self.init_std) + + def _output_init_normal(self, tensor): + output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers) + + def reset_parameters(self): + # initialize embedding + self._init_normal(self.embedding.word_embedding_weight) + self._init_normal(self.embedding.position_embeddings.weight) + if self.embedding.tokentype_embeddings: + self._init_normal(self.embedding.tokentype_embeddings.weight) + + # initialize bert layer + for layer in self.bert_layers: + # initialize self attention + self._init_normal(layer.self_attention.query_key_value.weight) + self._output_init_normal(layer.self_attention.dense.weight) + self._init_normal(layer.mlp.dense_h_to_4h.weight) + self._output_init_normal(layer.mlp.dense_4h_to_h.weight) + + # initializer head + self._init_normal(self.head.lm_head.dense.weight) + if self.head.binary_head is not None: + self._init_normal(self.head.binary_head.pooler.dense.weight) + self._init_normal(self.head.binary_head.dense.weight) + + def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels): + # inputs of the forward function + # input_ids: [batch_size, sub_seq_len] + # attention_mask: [batch_size, seq_len] + # tokentype_ids: [batch_size, sub_seq_len] + # outputs of preprocessor + # pos_ids: [batch_size, sub_seq_len] + # attention_masks: [batch_size, 1, sub_seq_len, seq_len] + pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks) + + hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids) + + # hidden_states shape change: + # [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size] + hidden_states = hidden_states.transpose(0, 1).contiguous() + + for idx, layer in enumerate(self.bert_layers): + hidden_states = layer(hidden_states, attention_masks) + + hidden_states = hidden_states.transpose(0, 1).contiguous() + output = self.layer_norm(hidden_states) + + # hidden_states: [sub_seq_len, batch_size, hidden_size] + # word_embedding: [vocab_size, hidden_size] + return self.head(output, self.embedding.word_embedding_weight, lm_labels) + + +class PipelineBertForPretrain(nn.Module): + + def __init__(self, + vocab_size, + hidden_size, + max_sequence_length, + num_attention_heads, + num_layers, + add_binary_head, + is_naive_fp16, + num_tokentypes=2, + dropout_prob=0.1, + mlp_ratio=4, + init_std=0.02, + convert_fp16_to_fp32_in_softmax=False, + first_stage=True, + last_stage=True, + start_idx=None, + end_idx=None): + super().__init__() + self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE) + assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size' + self.sub_seq_length = max_sequence_length // self.seq_parallel_size + self.init_std = init_std + self.num_layers = num_layers + + if not add_binary_head: + num_tokentypes = 0 + + self.first_stage = first_stage + self.last_stage = last_stage + + self.preprocessor = PreProcessor(self.sub_seq_length) + + if self.first_stage: + self.embedding = Embedding(hidden_size=hidden_size, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + embedding_dropout_prob=dropout_prob, + num_tokentypes=num_tokentypes) + + # transformer layers + self.bert_layers = nn.ModuleList() + + if start_idx is None and end_idx is None: + start_idx = 0 + end_idx = num_layers + + for i in range(start_idx, end_idx): + bert_layer = BertLayer(layer_number=i+1, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=dropout_prob, + mlp_ratio=mlp_ratio, + hidden_dropout=dropout_prob, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + is_naive_fp16=is_naive_fp16 + ) + self.bert_layers.append(bert_layer) + + if self.last_stage: + self.word_embeddings = VocabEmbedding(vocab_size, hidden_size) + self.layer_norm = LayerNorm(hidden_size) + self.head = BertDualHead(hidden_size, vocab_size, + add_binary_head=add_binary_head) + self.reset_parameters() + + def _init_normal(self, tensor): + init_normal(tensor, sigma=self.init_std) + + def _output_init_normal(self, tensor): + output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers) + + def reset_parameters(self): + # initialize embedding + if self.first_stage: + self._init_normal(self.embedding.word_embedding_weight) + self._init_normal(self.embedding.position_embeddings.weight) + if self.embedding.tokentype_embeddings: + self._init_normal(self.embedding.tokentype_embeddings.weight) + + # initialize bert layer + for layer in self.bert_layers: + # initialize self attention + self._init_normal(layer.self_attention.query_key_value.weight) + self._output_init_normal(layer.self_attention.dense.weight) + self._init_normal(layer.mlp.dense_h_to_4h.weight) + self._output_init_normal(layer.mlp.dense_4h_to_h.weight) + + # initializer head + if self.last_stage: + self._init_normal(self.head.lm_head.dense.weight) + if self.head.binary_head is not None: + self._init_normal(self.head.binary_head.pooler.dense.weight) + self._init_normal(self.head.binary_head.dense.weight) + + def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels): + # inputs of the forward function + # input_ids: [batch_size, sub_seq_len] + # attention_mask: [batch_size, seq_len] + # tokentype_ids: [batch_size, sub_seq_len] + # outputs of preprocessor + # pos_ids: [batch_size, sub_seq_len] + # attention_masks: [batch_size, 1, sub_seq_len, seq_len] + if self.first_stage: + pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks) + else: + _, attention_masks = self.preprocessor(None, attention_masks) + + if self.first_stage: + hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids) + hidden_states = hidden_states.transpose(0, 1).contiguous() + else: + hidden_states = input_ids + + # hidden_states shape change: + # [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size] + for idx, layer in enumerate(self.bert_layers): + hidden_states = layer(hidden_states, attention_masks) + + if self.last_stage: + hidden_states = hidden_states.transpose(0, 1).contiguous() + output = self.layer_norm(hidden_states) + output = self.head(output, self.word_embeddings.weight, lm_labels) + else: + output = hidden_states + + # hidden_states: [sub_seq_len, batch_size, hidden_size] + # word_embedding: [vocab_size, hidden_size] + return output + + +def _filter_kwargs(func, kwargs): + sig = inspect.signature(func) + return {k: v for k, v in kwargs.items() if k in sig.parameters} + + +def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): + logger = get_dist_logger() + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + rank = gpc.get_global_rank() + wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1]) + parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] + models = [] + for start, end in parts: + kwargs['num_layers'] = num_layers + kwargs['start_idx'] = start + kwargs['end_idx'] = end + kwargs['first_stage'] = start == 0 + kwargs['last_stage'] = end == num_layers + logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') + chunk = PipelineBertForPretrain(**_filter_kwargs(PipelineBertForPretrain.__init__, kwargs)).to(device) + if start == 0: + wrapper.register_module(chunk.embedding.word_embeddings) + elif end == num_layers: + wrapper.register_module(chunk.word_embeddings) + models.append(chunk) + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + return model diff --git a/examples/tutorial/sequence_parallel/model/layers/__init__.py b/examples/tutorial/sequence_parallel/model/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a8823caa81b4e1f3cdbf9ea4f79e012a1b99505 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/__init__.py @@ -0,0 +1,4 @@ +from .embedding import VocabEmbedding, Embedding +from .bert_layer import BertLayer +from .head import BertDualHead +from .preprocess import PreProcessor diff --git a/examples/tutorial/sequence_parallel/model/layers/bert_layer.py b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..4ede21516f65a8bc008a51dea505811fb0ceda48 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn +from colossalai.nn.layer.parallel_sequence import TransformerSelfAttentionRing +from colossalai.kernel.jit import bias_dropout_add_fused_train, bias_dropout_add_fused_inference +from colossalai.kernel.cuda_native import LayerNorm +from .mlp import TransformerMLP +from .dropout import get_bias_dropout_add + + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + +class BertLayer(nn.Module): + """A single transformer layer. + Transformer layer takes input with size [b, s, h] and returns an + output of the same size. + """ + + def __init__(self, + layer_number, + hidden_size, + num_attention_heads, + attention_dropout, + mlp_ratio, + hidden_dropout, + is_naive_fp16, + apply_residual_connection_post_layernorm=False, + fp32_residual_connection=False, + bias_dropout_fusion: bool = True, + convert_fp16_to_fp32_in_softmax: bool = False): + super().__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.fp32_residual_connection = fp32_residual_connection + + # Layernorm on the input data. + self.input_layernorm = LayerNorm(hidden_size) + + # Self attention. + self.self_attention = TransformerSelfAttentionRing( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=attention_dropout, + attention_mask_func=attention_mask_func, + layer_number=layer_number, + apply_query_key_layer_scaling=True, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + fp16=is_naive_fp16 + ) + + self.hidden_dropout = hidden_dropout + self.bias_dropout_fusion = bias_dropout_fusion + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNorm(hidden_size) + + self.mlp = TransformerMLP(hidden_size=hidden_size, mlp_ratio=mlp_ratio) + + def forward(self, hidden_states, attention_mask): + # hidden_states: [batch_size, sub_seq_len, hidden_size] + # attention_mask: [batch_size, 1, sub_seq_len, seq_len] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Self attention. + attention_output, attention_bias = self.self_attention(layernorm_output, attention_mask) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # jit scripting for a nn.module (with dropout) is not + # trigerring the fusion kernel. For now, we use two + # different nn.functional routines to account for varying + # dropout semantics during training and inference phases. + if self.bias_dropout_fusion: + if self.training: + bias_dropout_add_func = bias_dropout_add_fused_train + else: + bias_dropout_add_func = bias_dropout_add_fused_inference + else: + bias_dropout_add_func = get_bias_dropout_add(self.training) + + # re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + layernorm_input = bias_dropout_add_func( + attention_output, + attention_bias.expand_as(residual), + residual, + self.hidden_dropout) + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output, mlp_bias = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + # re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + output = bias_dropout_add_func( + mlp_output, + mlp_bias.expand_as(residual), + residual, + self.hidden_dropout) + + return output diff --git a/examples/tutorial/sequence_parallel/model/layers/dropout.py b/examples/tutorial/sequence_parallel/model/layers/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..0e99105b8f7e9d1ce2f884eba12678d5b18f4468 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/dropout.py @@ -0,0 +1,13 @@ +import torch + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor + out = torch.nn.functional.dropout(x + bias, p=prob, training=training) + out = residual + out + return out + + +def get_bias_dropout_add(training): + def _bias_dropout_add(x, bias, residual, prob): + return bias_dropout_add(x, bias, residual, prob, training) + return _bias_dropout_add \ No newline at end of file diff --git a/examples/tutorial/sequence_parallel/model/layers/embedding.py b/examples/tutorial/sequence_parallel/model/layers/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..0700d960d845ff1ab1f0258d6e174b88ad3ac902 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/embedding.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + + +class VocabEmbedding(torch.nn.Module): + + def __init__(self, num_embeddings, embedding_dim): + super(VocabEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = None + self.max_norm = None + self.norm_type = 2. + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + + # Allocate weights and initialize. + self.weight = nn.Parameter(torch.empty( + self.num_embeddings, self.embedding_dim)) + init.xavier_uniform_(self.weight) + + def forward(self, hidden_state): + output = F.embedding(hidden_state, self.weight, + self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, + self.sparse) + return output + + def __repr__(self): + return f'VocabEmbedding(num_embeddings={self.num_embeddings}, ' \ + f'embedding_dim={self.embedding_dim})' + + +class Embedding(nn.Module): + """Language model embeddings. + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, + hidden_size, + vocab_size, + max_sequence_length, + embedding_dropout_prob, + num_tokentypes): + super(Embedding, self).__init__() + + self.hidden_size = hidden_size + self.num_tokentypes = num_tokentypes + + self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size) + + # Position embedding (serial). + self.position_embeddings = torch.nn.Embedding( + max_sequence_length, self.hidden_size) + + # Token type embedding. + # Add this as an optional field that can be added through + # method call so we can load a pretrain model without + # token types and add them as needed. + if self.num_tokentypes > 0: + self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, + self.hidden_size) + else: + self.tokentype_embeddings = None + + # Embeddings dropout + self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) + + @property + def word_embedding_weight(self): + return self.word_embeddings.weight + + def forward(self, input_ids, position_ids, tokentype_ids=None): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + if tokentype_ids is not None and self.tokentype_embeddings is not None: + embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) + + # Dropout. + embeddings = self.embedding_dropout(embeddings) + + return embeddings diff --git a/examples/tutorial/sequence_parallel/model/layers/head.py b/examples/tutorial/sequence_parallel/model/layers/head.py new file mode 100644 index 0000000000000000000000000000000000000000..ea336b9d131e3405dc22d03f3e7b51add4e54466 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/head.py @@ -0,0 +1,78 @@ +import colossalai +import torch +import torch.nn as nn +import torch.nn.functional as F +from .pooler import Pooler +from .linear import Linear +from .embedding import VocabEmbedding +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.kernel import LayerNorm +from loss_func.cross_entropy import vocab_cross_entropy + + +class BertLMHead(nn.Module): + """Masked LM head for Bert + Arguments: + hidden_size: hidden size + init_method: init method for weight initialization + layernorm_epsilon: tolerance for layer norm divisions + """ + + def __init__(self, + vocab_size, + hidden_size, + ): + + super(BertLMHead, self).__init__() + self.bias = torch.nn.Parameter(torch.zeros(vocab_size)) + + self.dense = Linear(hidden_size, hidden_size) + self.layernorm = LayerNorm(hidden_size) + self.gelu = torch.nn.functional.gelu + + def forward(self, hidden_states, word_embeddings_weight, lm_labels): + hidden_states = self.dense(hidden_states) + hidden_states = self.gelu(hidden_states) + hidden_states = self.layernorm(hidden_states) + + output = F.linear(hidden_states, word_embeddings_weight, self.bias) + lm_loss = vocab_cross_entropy(output, lm_labels) + + return lm_loss + + +class BertBinaryHead(nn.Module): + + def __init__(self, hidden_size): + super().__init__() + self.pooler = Pooler(hidden_size) + self.dense = Linear(hidden_size, 2) + + def forward(self, hidden_states): + if gpc.get_local_rank(ParallelMode.SEQUENCE) == 0: + output = self.pooler(hidden_states) + output = self.dense(output) + else: + output = None + return output + + +class BertDualHead(nn.Module): + + def __init__(self, hidden_size, vocab_size, add_binary_head): + super().__init__() + self.lm_head = BertLMHead(vocab_size, hidden_size) + self.add_binary_head = add_binary_head + if add_binary_head: + self.binary_head = BertBinaryHead(hidden_size) + else: + self.binary_head = None + + def forward(self, hidden_states, word_embeddings_weight, lm_labels): + if self.add_binary_head: + binary_output = self.binary_head(hidden_states) + else: + binary_output = None + lm_loss = self.lm_head(hidden_states, word_embeddings_weight, lm_labels) + return lm_loss, binary_output diff --git a/examples/tutorial/sequence_parallel/model/layers/init_method.py b/examples/tutorial/sequence_parallel/model/layers/init_method.py new file mode 100644 index 0000000000000000000000000000000000000000..1b409dfe40541524891f70fc7c7d8297afa86999 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/init_method.py @@ -0,0 +1,12 @@ +import torch +import math + +def init_normal(tensor, sigma): + """Init method based on N(0, sigma).""" + torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + + +def output_init_normal(tensor, sigma, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = sigma / math.sqrt(2.0 * num_layers) + torch.nn.init.normal_(tensor, mean=0.0, std=std) diff --git a/examples/tutorial/sequence_parallel/model/layers/linear.py b/examples/tutorial/sequence_parallel/model/layers/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae7d671e2bf2312da315d35629dcdc12ca075de --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/linear.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +from torch.nn import Parameter +import torch.nn.functional as F +import torch.nn.init as init + + +class Linear(nn.Module): + """Linear layer with column parallelism. + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. + """ + + def __init__(self, + input_size, + output_size, + bias=True, + skip_bias_add=False): + super(Linear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + + self.weight = Parameter(torch.empty(self.output_size, + self.input_size, + )) + init.normal_(self.weight) + if bias: + self.bias = Parameter(torch.empty(self.output_size)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_): + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output = F.linear(input_, self.weight, bias) + + if self.skip_bias_add: + return output, self.bias + else: + return output + + def __repr__(self): + return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \ + f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})' diff --git a/examples/tutorial/sequence_parallel/model/layers/mlp.py b/examples/tutorial/sequence_parallel/model/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..a255de813d135e5c79a281c86462340337c5c036 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/mlp.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .linear import Linear +from colossalai.kernel.jit import bias_gelu_impl + + +class TransformerMLP(nn.Module): + """MLP. + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. At the end, dropout is also + applied. + """ + + def __init__(self, hidden_size, mlp_ratio, fuse_gelu=True): + super(TransformerMLP, self).__init__() + + # Project to 4h. + self.dense_h_to_4h = Linear( + hidden_size, + int(hidden_size*mlp_ratio), + skip_bias_add=True) + + self.bias_gelu_fusion = fuse_gelu + self.activation_func = F.gelu + + # Project back to h. + self.dense_4h_to_h = Linear( + int(hidden_size*mlp_ratio), + hidden_size, + skip_bias_add=True) + + def forward(self, hidden_states): + # hidden states should be in the shape of [s, b, h] + # it will be projects into [s, b, 4h] + # and projected back to [s, b, h] + intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) + + if self.bias_gelu_fusion: + intermediate_parallel = \ + bias_gelu_impl(intermediate_parallel, bias_parallel) + else: + intermediate_parallel = \ + self.activation_func(intermediate_parallel + bias_parallel) + + # [s, b, h] + output, output_bias = self.dense_4h_to_h(intermediate_parallel) + return output, output_bias diff --git a/examples/tutorial/sequence_parallel/model/layers/pooler.py b/examples/tutorial/sequence_parallel/model/layers/pooler.py new file mode 100644 index 0000000000000000000000000000000000000000..282ed114790b32618c1a92924bce403167d1b89e --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/pooler.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +from .linear import Linear + + +class Pooler(nn.Module): + """Pooler layer. + + Pool hidden states of a specific token (for example start of the + sequence) and add a linear transformation followed by a tanh. + + Arguments: + hidden_size: hidden size + init_method: weight initialization method for the linear layer. + bias is set to zero. + """ + + def __init__(self, hidden_size): + super(Pooler, self).__init__() + self.dense = Linear(hidden_size, hidden_size) + + def forward(self, hidden_states, sequence_index=0): + # hidden_states: [b, s, h] + # sequence_index: index of the token to pool. + pooled = hidden_states[:, sequence_index, :] + pooled = self.dense(pooled) + pooled = torch.tanh(pooled) + return pooled diff --git a/examples/tutorial/sequence_parallel/model/layers/preprocess.py b/examples/tutorial/sequence_parallel/model/layers/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..53a326ddacf14f7babaa18390a04b962bd84c809 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/preprocess.py @@ -0,0 +1,58 @@ +from colossalai.context.parallel_mode import ParallelMode +import torch +import torch.nn as nn +from colossalai.core import global_context as gpc + + +class PreProcessor(nn.Module): + + def __init__(self, sub_seq_length): + super().__init__() + self.sub_seq_length = sub_seq_length + + def bert_position_ids(self, token_ids): + # Create position ids + seq_length = token_ids.size(1) + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + position_ids = torch.arange(seq_length*local_rank, + seq_length * (local_rank+1), + dtype=torch.long, + device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids + + def bert_extended_attention_mask(self, attention_mask): + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + start_index = local_rank * self.sub_seq_length + end_index = (local_rank + 1) * self.sub_seq_length + + # We create a 3D attention mask from a 2D tensor mask. + # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = attention_mask.unsqueeze(2) + # [b, s/D, s] + attention_mask_bss = attention_mask_b1s * attention_mask_bs1 + + attention_mask_bss = attention_mask_bss[:, start_index:end_index, :] + + # [b, 1, s/D, s] + extended_attention_mask = attention_mask_bss.unsqueeze(1) + + # Convert attention mask to binary: + extended_attention_mask = (extended_attention_mask < 0.5) + + return extended_attention_mask + + def forward(self, input_ids=None, attention_mask=None): + if attention_mask is not None: + extended_attention_mask = self.bert_extended_attention_mask(attention_mask) + else: + extended_attention_mask = None + + if input_ids is not None: + position_ids = self.bert_position_ids(input_ids) + else: + position_ids = None + return position_ids, extended_attention_mask diff --git a/examples/tutorial/sequence_parallel/requirements.txt b/examples/tutorial/sequence_parallel/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b49a94554afb699639f9da7d04808be760c051c0 --- /dev/null +++ b/examples/tutorial/sequence_parallel/requirements.txt @@ -0,0 +1,2 @@ +colossalai +torch diff --git a/examples/tutorial/sequence_parallel/test_ci.sh b/examples/tutorial/sequence_parallel/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..7bc20de3b6e414a8bc74bca5fcb2d17ec18b1106 --- /dev/null +++ b/examples/tutorial/sequence_parallel/test_ci.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -euxo pipefail + +pip install -r requirements.txt + +# run test +colossalai run --nproc_per_node 4 train.py diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py new file mode 100644 index 0000000000000000000000000000000000000000..a89747b5845eb1a4f9e663373a015daf39d169a2 --- /dev/null +++ b/examples/tutorial/sequence_parallel/train.py @@ -0,0 +1,218 @@ +import argparse + +import torch +from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel +from data.dummy_dataloader import DummyDataloader +from loss_func.bert_loss import BertLoss +from lr_scheduler import AnnealingLR +from model.bert import BertForPretrain, build_pipeline_bert + +import colossalai +from colossalai.amp import AMP_TYPE +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.engine.schedule import PipelineSchedule +from colossalai.kernel import LayerNorm +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import FusedAdam +from colossalai.utils import MultiTimer, is_using_pp + + +def process_batch_data(batch_data): + tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = batch_data + if gpc.is_first_rank(ParallelMode.PIPELINE): + data = dict(input_ids=tokens, attention_masks=padding_mask, tokentype_ids=types, lm_labels=lm_labels) + else: + data = dict(attention_masks=padding_mask, tokentype_ids=types, lm_labels=lm_labels) + label = dict(loss_mask=loss_mask, sentence_order=sentence_order) + return data, label + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data") + return parser.parse_args() + + +def pipeline_data_process_func(stage_output, micro_batch_data): + tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data + if gpc.is_first_rank(ParallelMode.PIPELINE): + data = (tokens, padding_mask, types, lm_labels) + label = (loss_mask, sentence_order) + else: + data = (stage_output, padding_mask, types, lm_labels) + label = (loss_mask, sentence_order) + return data, label + + +def main(): + # initialize + args = parse_args() + colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl') + + logger = get_dist_logger() + + # build synthetic dataloader + BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA) + VOCAB_SIZE = 30528 + trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS, + vocab_size=VOCAB_SIZE, + seq_length=gpc.config.SEQ_LENGTH) + validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS, + vocab_size=VOCAB_SIZE, + seq_length=gpc.config.SEQ_LENGTH) + + logger.info("Dataloaders are built", ranks=[0]) + + # build model + if hasattr(gpc.config, 'fp16') and gpc.config.fp16.get('mode') == AMP_TYPE.NAIVE: + is_naive_fp16 = True + else: + is_naive_fp16 = False + + use_pipeline = is_using_pp() + kwargs = dict(vocab_size=VOCAB_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + max_sequence_length=gpc.config.SEQ_LENGTH, + num_attention_heads=gpc.config.NUM_ATTENTION_HEADS, + convert_fp16_to_fp32_in_softmax=True, + is_naive_fp16=is_naive_fp16, + add_binary_head=gpc.config.ADD_BINARY_HEAD) + + if use_pipeline: + model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **kwargs) + else: + model = BertForPretrain(num_layers=gpc.config.DEPTH, **kwargs) + + model = model.half() + model.reset_parameters() + logger.info(f"Model is built with softmax in fp32 = {is_naive_fp16}", ranks=[0]) + + total_numel = 0 + for p in model.parameters(): + total_numel += p.numel() + logger.info(f"This model has {total_numel} parameters") + + # build criterion + criterion = BertLoss() + logger.info("Criterion is built", ranks=[0]) + + # layernorm and bias has no weight decay + weight_decay_params = {'params': []} + no_weight_decay_params = {'params': [], 'weight_decay': 0.0} + for module_ in model.modules(): + if isinstance(module_, LayerNorm): + no_weight_decay_params['params'].extend([p for p in list(module_._parameters.values()) if p is not None]) + else: + weight_decay_params['params'].extend( + [p for n, p in list(module_._parameters.items()) if p is not None and n != 'bias']) + no_weight_decay_params['params'].extend( + [p for n, p in list(module_._parameters.items()) if p is not None and n == 'bias']) + + logger.info( + f"without weight decay param: {len(no_weight_decay_params['params'])}, with weight decay param: {len(weight_decay_params['params'])}" + ) + # optimizer + optimizer = FusedAdam((weight_decay_params, no_weight_decay_params), + lr=gpc.config.LR, + weight_decay=gpc.config.WEIGHT_DECAY) + logger.info("Optimizer is built", ranks=[0]) + + # lr scheduler + # follow Megatron-LM setting + warmup_steps = int(gpc.config.DECAY_ITERS * gpc.config.WARMUP_FRACTION) + lr_scheduler = AnnealingLR(optimizer=optimizer, + max_lr=gpc.config.LR, + min_lr=gpc.config.MIN_LR, + warmup_steps=warmup_steps, + decay_steps=gpc.config.DECAY_ITERS, + decay_style='linear') + logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps") + + # # init + engine, *dummy = colossalai.initialize(model, optimizer, criterion, verbose=True) + + # build timer + timer = MultiTimer() + skip_iters = 0 + + # build loss tracker + accumulated_train_loss = torch.zeros(1, dtype=torch.float32).cuda() + accumulated_eval_loss = torch.zeros(1, dtype=torch.float32).cuda() + + # build data iters for pipeline parallel + if use_pipeline: + train_data_iter = SequenceParallelDataIterator(trainloader) + valid_data_iter = SequenceParallelDataIterator(validloader) + engine.schedule.data_process_func = pipeline_data_process_func + + logger.info("start training") + + for step in range(1, gpc.config.TRAIN_ITERS + 1): + timer.start('train-iterations') + engine.train() + if use_pipeline: + engine.zero_grad() + _, _, train_loss = engine.execute_schedule(train_data_iter, return_output_label=False) + engine.step() + else: + tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel( + trainloader) + engine.zero_grad() + lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels) + train_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order) + engine.backward(train_loss) + engine.step() + timer.stop('train-iterations', keep_in_history=True) + + if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE): + accumulated_train_loss += train_loss + + lr_scheduler.step() + + if step % gpc.config.EVAL_INTERVAL == 0: + engine.eval() + + for j in range(gpc.config.EVAL_ITERS): + with torch.no_grad(): + if use_pipeline: + _, _, eval_loss = engine.execute_schedule(valid_data_iter, + forward_only=True, + return_output_label=False) + else: + tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel( + validloader) + lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels) + eval_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order) + + if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE): + accumulated_eval_loss += eval_loss + + if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE): + accumulated_eval_loss /= gpc.config.EVAL_ITERS + accumulated_train_loss /= gpc.config.EVAL_INTERVAL + + timer_string = [] + for n, t in timer: + timer_string.append(f"{n}: {t.get_history_mean()*1000:.5f}") + timer_string = ' | '.join(timer_string) + lr = list(engine.optimizer.param_groups)[0]['lr'] + loss_scale = engine.optimizer.optim.loss_scale.item() + + if gpc.is_initialized(ParallelMode.PIPELINE): + ranks = [gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1]] + else: + ranks = [0] + logger.info(f'Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} ' + + f'| Eval Loss: {accumulated_eval_loss.item():.5g} ' + f'| Loss Scale: {loss_scale}' + + f"| Learning rate: {lr} | " + timer_string, + ranks=ranks) + + for n, t in timer: + t.reset() + accumulated_eval_loss.zero_() + accumulated_train_loss.zero_() + + +if __name__ == '__main__': + main() diff --git a/inference b/inference new file mode 160000 index 0000000000000000000000000000000000000000..56b35f3c06eaac11b1bee633d1e836563f74bcea --- /dev/null +++ b/inference @@ -0,0 +1 @@ +Subproject commit 56b35f3c06eaac11b1bee633d1e836563f74bcea diff --git a/op_builder/README.md b/op_builder/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9c33a4a328d7c211fe30cc7ecd9bdfb96a152ff3 --- /dev/null +++ b/op_builder/README.md @@ -0,0 +1,32 @@ +# Build PyTorch Extensions + +## Overview + +Building PyTorch extensions can be a difficult task for users not from the system background. It is definitely frustrating if the users encounter many strange technical jargons when install Colossal-AI. Therefore, we will provide two methods of building the PyTorch extensions for the users. + +1. Build CUDA extensions when running `pip install` if `CUDA_EXT=1` +2. Build the extension during runtime + +The first method is more suitable for users who are familiar with CUDA environment configurations. The second method is for those who are not as they only need to build the kernel which is required by their program. + +These two methods have different advantages and disadvantages. +Method 1 is good because it allows the user to build all kernels during installation and directly import the kernel. They don't need to care about kernel building when running their program. However, installation may fail if they don't know how to configure their environments and this leads to much frustration. +Method 2 is good because it allows the user to only build the kernel they actually need, such that there is a lower probability that they encounter environment issue. However, it may slow down their program due to the first build and subsequence load. + +## PyTorch Extensions in Colossal-AI + +The project [DeepSpeed](https://github.com/microsoft/DeepSpeed) has proposed a [solution](https://github.com/microsoft/DeepSpeed/tree/master/op_builder) to support kernel-build during either installation or runtime. +We have adapted from DeepSpeed's solution to build extensions. The extension build requires two main functions from PyTorch: + +1. `torch.utils.cpp_extension.CUDAExtension`: used to build extensions in `setup.py` during `pip install`. +2. `torch.utils.cpp_extension.load`: used to build and load extension during runtime + +Please note that the extension build by `CUDAExtension` cannot be loaded by the `load` function and `load` will run its own build again (correct me if I am wrong). + +Based on the DeepSpeed's work, we have make several modifications and improvements: + +1. All pre-built kernels (those installed with `setup.py`) will be found in `colossalai._C` +2. All runtime-built kernels will be found in the default torch extension path, i.e. ~/.cache/colossalai/torch_extensions. (If we put the built kernels in the installed site-package directory, this will make pip uninstall incomplete) +3. Once a kernel is loaded, we will cache it in the builder to avoid repeated kernel loading. + +When loading the built kernel, we will first check if the pre-built one exists. If not, the runtime build will be triggered. diff --git a/op_builder/__init__.py b/op_builder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae7223b8c692d5570d74a1ef72c0fe1ad61c28a --- /dev/null +++ b/op_builder/__init__.py @@ -0,0 +1,23 @@ +from .cpu_adam import CPUAdamBuilder +from .fused_optim import FusedOptimBuilder +from .layernorm import LayerNormBuilder +from .moe import MOEBuilder +from .multi_head_attn import MultiHeadAttnBuilder +from .scaled_masked_softmax import ScaledMaskedSoftmaxBuilder +from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder + +ALL_OPS = { + 'cpu_adam': CPUAdamBuilder, + 'fused_optim': FusedOptimBuilder, + 'moe': MOEBuilder, + 'multi_head_attn': MultiHeadAttnBuilder, + 'scaled_masked_softmax': ScaledMaskedSoftmaxBuilder, + 'scaled_upper_triangle_masked_softmax': ScaledUpperTrainglemaskedSoftmaxBuilder, + 'layernorm': LayerNormBuilder, +} + +__all__ = [ + 'ALL_OPS', 'CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledMaskedSoftmaxBuilder', + 'ScaledUpperTrainglemaskedSoftmaxBuilder', 'MOEBuilder', 'MultiTensorSGDBuilder', 'MultiTensorAdamBuilder', + 'MultiTensorLambBuilder', 'MultiTensorScaleBuilder', 'MultiTensorL2NormBuilder' +] diff --git a/op_builder/builder.py b/op_builder/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..8396235e5cfe89c6a4edccc4c7eb33ceefbbe11b --- /dev/null +++ b/op_builder/builder.py @@ -0,0 +1,219 @@ +# This code has been adapted from the DeepSpeed library. +# Copyright (c) Microsoft Corporation. + +# Licensed under the MIT License. +import importlib +import os +import time +from abc import ABC, abstractmethod +from pathlib import Path +from typing import List, Optional + +from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0 + + +class Builder(ABC): + """ + Builder is the base class to build extensions for PyTorch. + + Args: + name (str): the name of the kernel to be built + prebuilt_import_path (str): the path where the extension is installed during pip install + """ + + def __init__(self, name: str, prebuilt_import_path: str): + self.name = name + self.prebuilt_import_path = prebuilt_import_path + self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + + # we store the op as an attribute to avoid repeated building and loading + self.cached_op_module = None + + assert prebuilt_import_path.startswith('colossalai._C'), \ + f'The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}' + + def relative_to_abs_path(self, code_path: str) -> str: + """ + This function takes in a path relative to the colossalai root directory and return the absolute path. + """ + op_builder_module_path = Path(__file__).parent + + # if we install from source + # the current file path will be op_builder/builder.py + # if we install via pip install colossalai + # the current file path will be colossalai/kernel/op_builder/builder.py + # this is because that the op_builder inside colossalai is a symlink + # this symlink will be replaced with actual files if we install via pypi + # thus we cannot tell the colossalai root directory by checking whether the op_builder + # is a symlink, we can only tell whether it is inside or outside colossalai + if str(op_builder_module_path).endswith('colossalai/kernel/op_builder'): + root_path = op_builder_module_path.parent.parent + else: + root_path = op_builder_module_path.parent.joinpath('colossalai') + + code_abs_path = root_path.joinpath(code_path) + return str(code_abs_path) + + def get_cuda_home_include(self): + """ + return include path inside the cuda home. + """ + from torch.utils.cpp_extension import CUDA_HOME + if CUDA_HOME is None: + raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") + cuda_include = os.path.join(CUDA_HOME, "include") + return cuda_include + + def csrc_abs_path(self, path): + return os.path.join(self.relative_to_abs_path('kernel/cuda_native/csrc'), path) + + # functions must be overrided begin + @abstractmethod + def sources_files(self) -> List[str]: + """ + This function should return a list of source files for extensions. + """ + raise NotImplementedError + + @abstractmethod + def include_dirs(self) -> List[str]: + """ + This function should return a list of include files for extensions. + """ + pass + + @abstractmethod + def cxx_flags(self) -> List[str]: + """ + This function should return a list of cxx compilation flags for extensions. + """ + pass + + @abstractmethod + def nvcc_flags(self) -> List[str]: + """ + This function should return a list of nvcc compilation flags for extensions. + """ + pass + + # functions must be overrided over + def strip_empty_entries(self, args): + ''' + Drop any empty strings from the list of compile and link flags + ''' + return [x for x in args if len(x) > 0] + + def import_op(self): + """ + This function will import the op module by its string name. + """ + return importlib.import_module(self.prebuilt_import_path) + + def check_runtime_build_environment(self): + """ + Check whether the system environment is ready for extension compilation. + """ + try: + import torch + from torch.utils.cpp_extension import CUDA_HOME + TORCH_AVAILABLE = True + except ImportError: + TORCH_AVAILABLE = False + CUDA_HOME = None + + if not TORCH_AVAILABLE: + raise ModuleNotFoundError( + "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions") + + if CUDA_HOME is None: + raise RuntimeError( + "CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions" + ) + + # make sure CUDA is available for compilation during + cuda_available = check_cuda_availability() + if not cuda_available: + raise RuntimeError("CUDA is not available on your system as torch.cuda.is_available() returns False.") + + # make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not + check_system_pytorch_cuda_match(CUDA_HOME) + + def load(self, verbose: Optional[bool] = None): + """ + load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel. + If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the + kernel is built during pip install, it can be accessed through `colossalai._C`. + + Warning: do not load this kernel repeatedly during model execution as it could slow down the training process. + + Args: + verbose (bool, optional): show detailed info. Defaults to True. + """ + if verbose is None: + verbose = os.environ.get('CAI_KERNEL_VERBOSE', '0') == '1' + # if the kernel has be compiled and cached, we directly use it + if self.cached_op_module is not None: + return self.cached_op_module + + try: + # if the kernel has been pre-built during installation + # we just directly import it + op_module = self.import_op() + if verbose: + print_rank_0( + f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building.") + except ImportError: + # check environment + self.check_runtime_build_environment() + + # time the kernel compilation + start_build = time.time() + + # construct the build directory + import torch + from torch.utils.cpp_extension import load + torch_version_major = torch.__version__.split('.')[0] + torch_version_minor = torch.__version__.split('.')[1] + torch_cuda_version = torch.version.cuda + home_directory = os.path.expanduser('~') + extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}" + build_directory = os.path.join(home_directory, extension_directory) + Path(build_directory).mkdir(parents=True, exist_ok=True) + + if verbose: + print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now") + + # load the kernel + op_module = load(name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_cuda_cflags=self.nvcc_flags(), + extra_ldflags=[], + build_directory=build_directory, + verbose=verbose) + + build_duration = time.time() - start_build + + # log jit compilation time + if verbose: + print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds") + + # cache the built/loaded kernel + self.cached_op_module = op_module + + return op_module + + def builder(self) -> 'CUDAExtension': + """ + get a CUDAExtension instance used for setup.py + """ + from torch.utils.cpp_extension import CUDAExtension + + return CUDAExtension(name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args={ + 'cxx': self.strip_empty_entries(self.cxx_flags()), + 'nvcc': self.strip_empty_entries(self.nvcc_flags()) + }) diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..500e2cc0eddc55c0ecb5b40be5e979d912adb0c2 --- /dev/null +++ b/op_builder/cpu_adam.py @@ -0,0 +1,38 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads + + +class CPUAdamBuilder(Builder): + NAME = "cpu_adam" + PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam" + + def __init__(self): + super().__init__(name=CPUAdamBuilder.NAME, prebuilt_import_path=CPUAdamBuilder.PREBUILT_IMPORT_PATH) + self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path('cpu_adam.cpp'), + ] + return ret + + def include_dirs(self): + return [ + self.csrc_abs_path("includes"), + self.get_cuda_home_include() + ] + + def cxx_flags(self): + extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native'] + return ['-O3'] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + extra_cuda_flags = [ + '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + ] + ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/fused_optim.py b/op_builder/fused_optim.py new file mode 100644 index 0000000000000000000000000000000000000000..31ddfced1db24f18fb37cb549dc27060ae2fcff6 --- /dev/null +++ b/op_builder/fused_optim.py @@ -0,0 +1,34 @@ +import os + +from .builder import Builder +from .utils import get_cuda_cc_flag + + +class FusedOptimBuilder(Builder): + NAME = "fused_optim" + PREBUILT_IMPORT_PATH = "colossalai._C.fused_optim" + + def __init__(self): + super().__init__(name=FusedOptimBuilder.NAME, prebuilt_import_path=FusedOptimBuilder.PREBUILT_IMPORT_PATH) + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) for fname in [ + 'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu', + 'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu' + ] + ] + return ret + + def include_dirs(self): + ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + return ['-O3'] + version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ['-lineinfo'] + extra_cuda_flags.extend(get_cuda_cc_flag()) + return ['-O3', '--use_fast_math'] + extra_cuda_flags diff --git a/op_builder/layernorm.py b/op_builder/layernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..61d9417419293c4cc3d4835406f7b7756ffda845 --- /dev/null +++ b/op_builder/layernorm.py @@ -0,0 +1,29 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + + +class LayerNormBuilder(Builder): + NAME = "layernorm" + PREBUILT_IMPORT_PATH = "colossalai._C.layernorm" + + def __init__(self): + super().__init__(name=LayerNormBuilder.NAME, prebuilt_import_path=LayerNormBuilder.PREBUILT_IMPORT_PATH) + + def sources_files(self): + ret = [self.csrc_abs_path(fname) for fname in ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu']] + return ret + + def include_dirs(self): + ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ['-maxrregcount=50'] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + self.version_dependent_macros + return append_nvcc_threads(ret) diff --git a/op_builder/moe.py b/op_builder/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb7d8e3980c095d297fb985fbd961a3f852602a --- /dev/null +++ b/op_builder/moe.py @@ -0,0 +1,36 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + + +class MOEBuilder(Builder): + + NAME = "moe" + PREBUILT_IMPORT_PATH = "colossalai._C.moe" + + def __init__(self): + super().__init__(name=MOEBuilder.NAME, prebuilt_import_path=MOEBuilder.PREBUILT_IMPORT_PATH) + + def include_dirs(self): + ret = [ + self.csrc_abs_path("kernels/include"), + self.get_cuda_home_include() + ] + return ret + + def sources_files(self): + ret = [self.csrc_abs_path(fname) for fname in ['moe_cuda.cpp', 'moe_cuda_kernel.cu']] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', + '--expt-extended-lambda' + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/multi_head_attn.py b/op_builder/multi_head_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..f9103fe947297441d708dbbb42a5ce748ac5c3d7 --- /dev/null +++ b/op_builder/multi_head_attn.py @@ -0,0 +1,41 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + + +class MultiHeadAttnBuilder(Builder): + + NAME = "multihead_attention" + PREBUILT_IMPORT_PATH = "colossalai._C.multihead_attention" + + def __init__(self): + super().__init__(name=MultiHeadAttnBuilder.NAME, + prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) + + + def include_dirs(self): + ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] + return ret + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) for fname in [ + 'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu', + 'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu', + 'kernels/general_kernels.cu', 'kernels/cuda_util.cu' + ] + ] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/scaled_masked_softmax.py b/op_builder/scaled_masked_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..11cfda39a85c799578653c909cf79539fc24cb48 --- /dev/null +++ b/op_builder/scaled_masked_softmax.py @@ -0,0 +1,37 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads + + +class ScaledMaskedSoftmaxBuilder(Builder): + NAME = "scaled_masked_softmax" + PREBUILT_IMPORT_PATH = "colossalai._C.scaled_masked_softmax" + + def __init__(self): + super().__init__(name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) for fname in + ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'] + ] + return ret + + def include_dirs(self): + return [ + self.csrc_abs_path("kernels/include"), + self.get_cuda_home_include() + ] + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + ] + ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/scaled_upper_triangle_masked_softmax.py b/op_builder/scaled_upper_triangle_masked_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d2433aa64527abb662381b08b42aabb9d61807 --- /dev/null +++ b/op_builder/scaled_upper_triangle_masked_softmax.py @@ -0,0 +1,37 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + + +class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder): + NAME = "scaled_upper_triangle_masked_softmax" + PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax" + + def __init__(self): + super().__init__(name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) + + def include_dirs(self): + return [ + self.csrc_abs_path("kernels/include"), + self.get_cuda_home_include() + ] + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'] + ] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', + '--expt-extended-lambda' + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/utils.py b/op_builder/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1b1bd5f499707b685903771d14d466023f0d66b7 --- /dev/null +++ b/op_builder/utils.py @@ -0,0 +1,221 @@ +import os +import re +import subprocess +import warnings +from typing import List + + +def print_rank_0(message: str) -> None: + """ + Print on only one process to avoid spamming. + """ + try: + import torch.distributed as dist + if not dist.is_initialized(): + is_main_rank = True + else: + is_main_rank = dist.get_rank() == 0 + except ImportError: + is_main_rank = True + + if is_main_rank: + print(message) + + +def get_cuda_version_in_pytorch() -> List[int]: + """ + This function returns the CUDA version in the PyTorch build. + + Returns: + The CUDA version required by PyTorch, in the form of tuple (major, minor). + """ + import torch + + try: + torch_cuda_major = torch.version.cuda.split(".")[0] + torch_cuda_minor = torch.version.cuda.split(".")[1] + except: + raise ValueError( + "[extension] Cannot retrive the CUDA version in the PyTorch binary given by torch.version.cuda") + return torch_cuda_major, torch_cuda_minor + + +def get_cuda_bare_metal_version(cuda_dir) -> List[int]: + """ + Get the System CUDA version from nvcc. + + Args: + cuda_dir (str): the directory for CUDA Toolkit. + + Returns: + The CUDA version required by PyTorch, in the form of tuple (major, minor). + """ + nvcc_path = os.path.join(cuda_dir, 'bin/nvcc') + + if cuda_dir is None: + raise ValueError( + f"[extension] The argument cuda_dir is None, but expected to be a string. Please make sure your have exported the environment variable CUDA_HOME correctly." + ) + + # check for nvcc path + if not os.path.exists(nvcc_path): + raise FileNotFoundError( + f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME." + ) + + # parse the nvcc -v output to obtain the system cuda version + try: + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + except: + raise ValueError( + f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. The output for 'nvcc -v' is \n{raw_output}" + ) + + return bare_metal_major, bare_metal_minor + + +def check_system_pytorch_cuda_match(cuda_dir): + bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) + torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch() + + if bare_metal_major != torch_cuda_major: + raise Exception( + f'[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) ' + f'mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor}).' + 'Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ .' + ) + + if bare_metal_minor != torch_cuda_minor: + warnings.warn( + f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. " + "The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. " + "If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions" + ) + return True + + +def get_pytorch_version() -> List[int]: + """ + This functions finds the PyTorch version. + + Returns: + A tuple of integers in the form of (major, minor, patch). + """ + import torch + torch_version = torch.__version__.split('+')[0] + TORCH_MAJOR = int(torch_version.split('.')[0]) + TORCH_MINOR = int(torch_version.split('.')[1]) + TORCH_PATCH = int(torch_version.split('.')[2]) + return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH + + +def check_pytorch_version(min_major_version, min_minor_version) -> bool: + """ + Compare the current PyTorch version with the minium required version. + + Args: + min_major_version (int): the minimum major version of PyTorch required + min_minor_version (int): the minimum minor version of PyTorch required + + Returns: + A boolean value. The value is True if the current pytorch version is acceptable and False otherwise. + """ + # get pytorch version + torch_major, torch_minor, _ = get_pytorch_version() + + # if the + if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version): + raise RuntimeError( + f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n" + "The latest stable release can be obtained from https://pytorch.org/get-started/locally/") + + +def check_cuda_availability(): + """ + Check if CUDA is available on the system. + + Returns: + A boolean value. True if CUDA is available and False otherwise. + """ + import torch + return torch.cuda.is_available() + + +def set_cuda_arch_list(cuda_dir): + """ + This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation. + Ahead-of-time compilation occurs when CUDA_EXT=1 is set when running 'pip install'. + """ + cuda_available = check_cuda_availability() + + # we only need to set this when CUDA is not available for cross-compilation + if not cuda_available: + warnings.warn('\n[extension] PyTorch did not find available GPUs on this system.\n' + 'If your intention is to cross-compile, this is not an error.\n' + 'By default, Colossal-AI will cross-compile for \n' + '1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n' + '2. Volta (compute capability 7.0)\n' + '3. Turing (compute capability 7.5),\n' + '4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n' + '\nIf you wish to cross-compile for a single specific architecture,\n' + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n') + + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: + bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) + + arch_list = ['6.0', '6.1', '6.2', '7.0', '7.5'] + + if int(bare_metal_major) == 11: + if int(bare_metal_minor) == 0: + arch_list.append('8.0') + else: + arch_list.append('8.0') + arch_list.append('8.6') + + arch_list_str = ';'.join(arch_list) + os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str + return False + return True + + +def get_cuda_cc_flag() -> List[str]: + """ + This function produces the cc flags for your GPU arch + + Returns: + The CUDA cc flags for compilation. + """ + + # only import torch when needed + # this is to avoid importing torch when building on a machine without torch pre-installed + # one case is to build wheel for pypi release + import torch + + cc_flag = [] + for arch in torch.cuda.get_arch_list(): + res = re.search(r'sm_(\d+)', arch) + if res: + arch_cap = res[1] + if int(arch_cap) >= 60: + cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) + return cc_flag + + +def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]: + """ + This function appends the threads flag to your nvcc args. + + Returns: + The nvcc compilation flags including the threads flag. + """ + from torch.utils.cpp_extension import CUDA_HOME + + bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) + if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..ac31ace4bfae025025b1098719aba873db615d1c --- /dev/null +++ b/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +markers = + cpu: tests which can run on CPU + gpu: tests which requires a single GPU + dist: tests which are run in a multi-GPU or multi-machine environment + experiment: tests for experimental features \ No newline at end of file diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt new file mode 100644 index 0000000000000000000000000000000000000000..82b6173b351791654e71034c4b445b04babf480b --- /dev/null +++ b/requirements/requirements-test.txt @@ -0,0 +1,15 @@ +diffusers +fbgemm-gpu==0.2.0 +pytest +pytest-cov +torchvision +transformers +timm +titans +torchaudio +torchrec==0.2.0 +contexttimer +einops +triton==2.0.0.dev20221202 +git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn +requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 diff --git a/requirements/requirements.txt b/requirements/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b34dc2e223ae7c8144c8e0b95eadf6646aaca4e9 --- /dev/null +++ b/requirements/requirements.txt @@ -0,0 +1,12 @@ +numpy +tqdm +psutil +packaging +pre-commit +rich +click +fabric +contexttimer +ninja +torch>=1.11 +safetensors diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..5d8f831218d95a3d1cad775083f2652da8a19c15 --- /dev/null +++ b/setup.py @@ -0,0 +1,198 @@ +import os +import sys +from datetime import datetime +from typing import List + +from setuptools import find_packages, setup + +from op_builder.utils import ( + check_cuda_availability, + check_pytorch_version, + check_system_pytorch_cuda_match, + get_cuda_bare_metal_version, + get_pytorch_version, + set_cuda_arch_list, +) + +try: + import torch + from torch.utils.cpp_extension import CUDA_HOME, BuildExtension + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + CUDA_HOME = None + +# Some constants for installation checks +MIN_PYTORCH_VERSION_MAJOR = 1 +MIN_PYTORCH_VERSION_MINOR = 10 +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +BUILD_CUDA_EXT = int(os.environ.get('CUDA_EXT', '0')) == 1 +IS_NIGHTLY = int(os.environ.get('NIGHTLY', '0')) == 1 + +# a variable to store the op builder +ext_modules = [] + +# we do not support windows currently +if sys.platform == 'win32': + raise RuntimeError("Windows is not supported yet. Please try again within the Windows Subsystem for Linux (WSL).") + + +# check for CUDA extension dependencies +def environment_check_for_cuda_extension_build(): + if not TORCH_AVAILABLE: + raise ModuleNotFoundError( + "[extension] PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions" + ) + + if not CUDA_HOME: + raise RuntimeError( + "[extension] CUDA_HOME is not found while CUDA_EXT=1. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions" + ) + + check_system_pytorch_cuda_match(CUDA_HOME) + check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR) + check_cuda_availability() + + +def fetch_requirements(path) -> List[str]: + """ + This function reads the requirements file. + + Args: + path (str): the path to the requirements file. + + Returns: + The lines in the requirements file. + """ + with open(path, 'r') as fd: + return [r.strip() for r in fd.readlines()] + + +def fetch_readme() -> str: + """ + This function reads the README.md file in the current directory. + + Returns: + The lines in the README file. + """ + with open('README.md', encoding='utf-8') as f: + return f.read() + + +def get_version() -> str: + """ + This function reads the version.txt and generates the colossalai/version.py file. + + Returns: + The library version stored in version.txt. + """ + + setup_file_path = os.path.abspath(__file__) + project_path = os.path.dirname(setup_file_path) + version_txt_path = os.path.join(project_path, 'version.txt') + version_py_path = os.path.join(project_path, 'colossalai/version.py') + + with open(version_txt_path) as f: + version = f.read().strip() + + # write version into version.py + with open(version_py_path, 'w') as f: + f.write(f"__version__ = '{version}'\n") + + # look for pytorch and cuda version + if BUILD_CUDA_EXT: + torch_major, torch_minor, _ = get_pytorch_version() + torch_version = f'{torch_major}.{torch_minor}' + cuda_version = '.'.join(get_cuda_bare_metal_version(CUDA_HOME)) + else: + torch_version = None + cuda_version = None + + # write the version into the python file + if torch_version: + f.write(f'torch = "{torch_version}"\n') + else: + f.write('torch = None\n') + + if cuda_version: + f.write(f'cuda = "{cuda_version}"\n') + else: + f.write('cuda = None\n') + + return version + + +if BUILD_CUDA_EXT: + environment_check_for_cuda_extension_build() + set_cuda_arch_list(CUDA_HOME) + + from op_builder import ALL_OPS + op_names = [] + + # load all builders + for name, builder_cls in ALL_OPS.items(): + op_names.append(name) + ext_modules.append(builder_cls().builder()) + + # show log + op_name_list = ', '.join(op_names) + print(f"[extension] loaded builders for {op_name_list}") + +# always put not nightly branch as the if branch +# otherwise github will treat colossalai-nightly as the project name +# and it will mess up with the dependency graph insights +if not IS_NIGHTLY: + version = get_version() + package_name = 'colossalai' +else: + # use date as the nightly version + version = datetime.today().strftime('%Y.%m.%d') + package_name = 'colossalai-nightly' + +setup(name=package_name, + version=version, + packages=find_packages(exclude=( + 'op_builder', + 'benchmark', + 'docker', + 'tests', + 'docs', + 'examples', + 'tests', + 'scripts', + 'requirements', + '*.egg-info', + )), + description='An integrated large-scale model training system with efficient parallelization techniques', + long_description=fetch_readme(), + long_description_content_type='text/markdown', + license='Apache Software License 2.0', + url='https://www.colossalai.org', + project_urls={ + 'Forum': 'https://github.com/hpcaitech/ColossalAI/discussions', + 'Bug Tracker': 'https://github.com/hpcaitech/ColossalAI/issues', + 'Examples': 'https://github.com/hpcaitech/ColossalAI-Examples', + 'Documentation': 'http://colossalai.readthedocs.io', + 'Github': 'https://github.com/hpcaitech/ColossalAI', + }, + ext_modules=ext_modules, + cmdclass={'build_ext': BuildExtension} if ext_modules else {}, + install_requires=fetch_requirements('requirements/requirements.txt'), + entry_points=''' + [console_scripts] + colossalai=colossalai.cli:cli + ''', + python_requires='>=3.6', + classifiers=[ + 'Programming Language :: Python :: 3', + 'License :: OSI Approved :: Apache Software License', + 'Environment :: GPU :: NVIDIA CUDA', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: System :: Distributed Computing', + ], + package_data={ + 'colossalai': [ + '_C/*.pyi', 'kernel/cuda_native/csrc/*', 'kernel/cuda_native/csrc/kernel/*', + 'kernel/cuda_native/csrc/kernels/include/*' + ] + }) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/components_to_test/__init__.py b/tests/components_to_test/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f29efefce4a461bb71ea2f65d10a57a8b1439320 --- /dev/null +++ b/tests/components_to_test/__init__.py @@ -0,0 +1,19 @@ +from . import ( + beit, + bert, + gpt2, + hanging_param_model, + inline_op_model, + nested_model, + repeated_computed_layers, + resnet, + simple_net, +) +from .utils import run_fwd, run_fwd_bwd + +from . import albert # isort:skip + +__all__ = [ + 'bert', 'gpt2', 'hanging_param_model', 'inline_op_model', 'nested_model', 'repeated_computed_layers', 'resnet', + 'simple_net', 'run_fwd_bwd', 'albert', 'beit', 'run_fwd' +] diff --git a/tests/components_to_test/albert.py b/tests/components_to_test/albert.py new file mode 100644 index 0000000000000000000000000000000000000000..d5b6bc89a83e0ae49a4045800b6e9d57ff848604 --- /dev/null +++ b/tests/components_to_test/albert.py @@ -0,0 +1,59 @@ +import torch +import transformers +from packaging import version +from transformers import AlbertConfig, AlbertForSequenceClassification + +from .bert import get_bert_data_loader +from .registry import non_distributed_component_funcs + + +@non_distributed_component_funcs.register(name='albert') +def get_training_components(): + hidden_dim = 8 + num_head = 4 + sequence_length = 12 + num_layer = 2 + vocab_size = 32 + + def bert_model_builder(checkpoint: bool = False): + config = AlbertConfig(vocab_size=vocab_size, + gradient_checkpointing=checkpoint, + hidden_size=hidden_dim, + intermediate_size=hidden_dim * 4, + num_attention_heads=num_head, + max_position_embeddings=sequence_length, + num_hidden_layers=num_layer, + hidden_dropout_prob=0., + attention_probs_dropout_prob=0.) + print('building AlbertForSequenceClassification model') + + # adapting huggingface BertForSequenceClassification for single unitest calling interface + class ModelAaptor(AlbertForSequenceClassification): + + def forward(self, input_ids, labels): + """ + inputs: data, label + outputs: loss + """ + return super().forward(input_ids=input_ids, labels=labels)[0] + + model = ModelAaptor(config) + # if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"): + # model.gradient_checkpointing_enable() + + return model + + is_distrbuted = torch.distributed.is_initialized() + trainloader = get_bert_data_loader(n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distrbuted=is_distrbuted) + testloader = get_bert_data_loader(n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distrbuted=is_distrbuted) + + criterion = None + return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/beit.py b/tests/components_to_test/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..1252071f40759149606267b3d18b495d1a3f490d --- /dev/null +++ b/tests/components_to_test/beit.py @@ -0,0 +1,42 @@ +import torch +from timm.models.beit import Beit + +from colossalai.utils.cuda import get_current_device + +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + +class DummyDataLoader(DummyDataGenerator): + img_size = 64 + num_channel = 3 + num_class = 10 + batch_size = 4 + + def generate(self): + data = torch.randn((DummyDataLoader.batch_size, DummyDataLoader.num_channel, DummyDataLoader.img_size, + DummyDataLoader.img_size), + device=get_current_device()) + label = torch.randint(low=0, + high=DummyDataLoader.num_class, + size=(DummyDataLoader.batch_size,), + device=get_current_device()) + return data, label + + +@non_distributed_component_funcs.register(name='beit') +def get_training_components(): + + def model_buider(checkpoint=False): + model = Beit(img_size=DummyDataLoader.img_size, + num_classes=DummyDataLoader.num_class, + embed_dim=32, + depth=2, + num_heads=4) + return model + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + return model_buider, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..c1faa6f9d892650c4717d7baa359165822a8165b --- /dev/null +++ b/tests/components_to_test/bert.py @@ -0,0 +1,84 @@ +import torch +import transformers +from packaging import version +from torch.utils.data import SequentialSampler +from transformers import BertConfig, BertForSequenceClassification + +from .registry import non_distributed_component_funcs + + +def get_bert_data_loader( + n_class, + batch_size, + total_samples, + sequence_length, + device=torch.device('cpu:0'), + is_distrbuted=False, +): + train_data = torch.randint( + low=0, + high=n_class, + size=(total_samples, sequence_length), + device=device, + dtype=torch.long, + ) + train_label = torch.randint(low=0, high=2, size=(total_samples,), device=device, dtype=torch.long) + train_dataset = torch.utils.data.TensorDataset(train_data, train_label) + if is_distrbuted: + sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + sampler = SequentialSampler(train_dataset) + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler) + return train_loader + + +@non_distributed_component_funcs.register(name='bert') +def get_training_components(): + hidden_dim = 8 + num_head = 4 + sequence_length = 12 + num_layer = 2 + vocab_size = 32 + + def bert_model_builder(checkpoint: bool = False): + config = BertConfig(vocab_size=vocab_size, + gradient_checkpointing=checkpoint, + hidden_size=hidden_dim, + intermediate_size=hidden_dim * 4, + num_attention_heads=num_head, + max_position_embeddings=sequence_length, + num_hidden_layers=num_layer, + hidden_dropout_prob=0., + attention_probs_dropout_prob=0.) + print('building BertForSequenceClassification model') + + # adapting huggingface BertForSequenceClassification for single unitest calling interface + class ModelAaptor(BertForSequenceClassification): + + def forward(self, input_ids, labels): + """ + inputs: data, label + outputs: loss + """ + return super().forward(input_ids=input_ids, labels=labels)[0] + + model = ModelAaptor(config) + if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"): + model.gradient_checkpointing_enable() + + return model + + is_distrbuted = torch.distributed.is_initialized() + trainloader = get_bert_data_loader(n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distrbuted=is_distrbuted) + testloader = get_bert_data_loader(n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distrbuted=is_distrbuted) + + criterion = None + return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/gpt2.py b/tests/components_to_test/gpt2.py new file mode 100644 index 0000000000000000000000000000000000000000..fe25b4923fa27493be7f006956ce642be6047503 --- /dev/null +++ b/tests/components_to_test/gpt2.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +from transformers import GPT2Config, GPT2LMHeadModel + +from colossalai.utils.cuda import get_current_device + +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + +class DummyDataLoader(DummyDataGenerator): + vocab_size = 128 + batch_size = 4 + seq_len = 64 + + def generate(self): + input_ids = torch.randint(0, + DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len), + device=get_current_device()) + return input_ids, input_ids + + +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50304, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids): + # Only return lm_logits + attention_mask = torch.ones_like(input_ids) + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +def gpt2_micro(checkpoint=True): + return GPTLMModel(checkpoint=checkpoint, + hidden_size=32, + num_layers=2, + num_attention_heads=4, + max_seq_len=64, + vocab_size=128) + + +def gpt2_s(checkpoint=True): + return GPTLMModel(checkpoint=checkpoint) + + +def gpt2_m(checkpoint=True): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +@non_distributed_component_funcs.register(name='gpt2') +def get_training_components(): + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = GPTLMLoss() + return gpt2_micro, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/hanging_param_model.py b/tests/components_to_test/hanging_param_model.py new file mode 100644 index 0000000000000000000000000000000000000000..329a08ea28f0224226c6061c2894982f3cd1d397 --- /dev/null +++ b/tests/components_to_test/hanging_param_model.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.nn import CheckpointModule + +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + +class HangingParamModule(CheckpointModule): + """ + Hanging Parameter: a parameter dose not belong to a leaf Module. + It has subordinate nn.modules and a nn.Parameter. + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.proj1 = nn.Linear(4, 8) + self.weight = nn.Parameter(torch.randn(8, 8)) + self.proj2 = nn.Linear(8, 4) + + def forward(self, x): + x = self.proj1(x) + x = F.linear(x, self.weight) + x = self.proj2(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.rand(16, 4) + label = torch.randint(low=0, high=2, size=(16,)) + return data, label + + +@non_distributed_component_funcs.register(name='hanging_param_model') +def get_training_components(): + + def model_builder(checkpoint=False): + return HangingParamModule(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/inline_op_model.py b/tests/components_to_test/inline_op_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f061d48f92c6eb7a2095fed039f2d2e49c1e07d3 --- /dev/null +++ b/tests/components_to_test/inline_op_model.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.nn import CheckpointModule + +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + +class InlineOpModule(CheckpointModule): + """ + a module with inline Ops + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.proj1 = nn.Linear(4, 8) + self.proj2 = nn.Linear(8, 8) + + def forward(self, x): + + x = self.proj1(x) + # inline add_ + x.add_(10) + x = self.proj2(x) + # inline relu_ + x = torch.relu_(x) + x = self.proj2(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.rand(16, 4) + label = torch.randint(low=0, high=2, size=(16,)) + return data, label + + +@non_distributed_component_funcs.register(name='inline_op_model') +def get_training_components(): + + def model_builder(checkpoint=False): + return InlineOpModule(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/nested_model.py b/tests/components_to_test/nested_model.py new file mode 100644 index 0000000000000000000000000000000000000000..339084639244ef4f434218ff77673b46b9deb5c1 --- /dev/null +++ b/tests/components_to_test/nested_model.py @@ -0,0 +1,55 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.nn import CheckpointModule + +from .registry import non_distributed_component_funcs +from .utils import DummyDataGenerator + + +class SubNet(nn.Module): + + def __init__(self, out_features) -> None: + super().__init__() + self.bias = nn.Parameter(torch.zeros(out_features)) + + def forward(self, x, weight): + return F.linear(x, weight, self.bias) + + +class NestedNet(CheckpointModule): + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint) + self.fc1 = nn.Linear(5, 5) + self.sub_fc = SubNet(5) + self.fc2 = nn.Linear(5, 2) + + def forward(self, x): + x = self.fc1(x) + x = self.sub_fc(x, self.fc1.weight) + x = self.fc1(x) + x = self.fc2(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.rand(16, 5) + label = torch.randint(low=0, high=2, size=(16,)) + return data, label + + +@non_distributed_component_funcs.register(name='nested_model') +def get_training_components(): + + def model_builder(checkpoint=False): + return NestedNet(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + return model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/registry.py b/tests/components_to_test/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..728ed9eba6ea176b70f152751f877781a5beb214 --- /dev/null +++ b/tests/components_to_test/registry.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python + + +class Registry: + + def __init__(self): + self._registry = dict() + + def register(self, name): + assert name not in self._registry + + def _regsiter(callable_): + self._registry[name] = callable_ + + return _regsiter + + def get_callable(self, name: str): + return self._registry[name] + + def __iter__(self): + self._idx = 0 + self._len = len(self._registry) + self._names = list(self._registry.keys()) + return self + + def __next__(self): + if self._idx < self._len: + key = self._names[self._idx] + callable_ = self._registry[key] + self._idx += 1 + return callable_ + else: + raise StopIteration + + +non_distributed_component_funcs = Registry() +model_paralle_component_funcs = Registry() + +__all__ = ['non_distributed_component_funcs', 'model_paralle_component_funcs'] diff --git a/tests/components_to_test/repeated_computed_layers.py b/tests/components_to_test/repeated_computed_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f84bd0e203eb3fd2d23ae627048f422cfc8a51 --- /dev/null +++ b/tests/components_to_test/repeated_computed_layers.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python + +import torch +import torch.nn as nn + +from colossalai.nn import CheckpointModule + +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + +class NetWithRepeatedlyComputedLayers(CheckpointModule): + """ + This model is to test with layers which go through forward pass multiple times. + In this model, the fc1 and fc2 call forward twice + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.fc1 = nn.Linear(5, 5) + self.fc2 = nn.Linear(5, 5) + self.fc3 = nn.Linear(5, 2) + self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3] + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.rand(16, 5) + label = torch.randint(low=0, high=2, size=(16,)) + return data, label + + +@non_distributed_component_funcs.register(name='repeated_computed_layers') +def get_training_components(): + + def model_builder(checkpoint=False): + return NetWithRepeatedlyComputedLayers(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + return model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/resnet.py b/tests/components_to_test/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..193832ebc12da0c3f6d18a8bb21fa9a67a2ac02c --- /dev/null +++ b/tests/components_to_test/resnet.py @@ -0,0 +1,33 @@ +from torchvision.models import resnet18 +from .registry import non_distributed_component_funcs +from pathlib import Path +import os +import torch +from torchvision.transforms import transforms +from torchvision.datasets import CIFAR10 +from colossalai.utils import get_dataloader + + +def get_cifar10_dataloader(train): + # build dataloaders + dataset = CIFAR10(root=Path(os.environ['DATA']), + download=True, + train=train, + transform=transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])) + dataloader = get_dataloader(dataset=dataset, shuffle=True, batch_size=16, drop_last=True) + return dataloader + + +@non_distributed_component_funcs.register(name='resnet18') +def get_resnet_training_components(): + + def model_builder(checkpoint=False): + return resnet18(num_classes=10) + + trainloader = get_cifar10_dataloader(train=True) + testloader = get_cifar10_dataloader(train=False) + + criterion = torch.nn.CrossEntropyLoss() + return model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/simple_net.py b/tests/components_to_test/simple_net.py new file mode 100644 index 0000000000000000000000000000000000000000..cd9d7ebc0b1a1a699e427905269e8e8a9b0936bf --- /dev/null +++ b/tests/components_to_test/simple_net.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn + +from colossalai.nn import CheckpointModule +from colossalai.utils.cuda import get_current_device + +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + +class SimpleNet(CheckpointModule): + """ + In this no-leaf module, it has subordinate nn.modules and a nn.Parameter. + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.embed = nn.Embedding(20, 4) + self.proj1 = nn.Linear(4, 8) + self.ln1 = nn.LayerNorm(8) + self.proj2 = nn.Linear(8, 4) + self.ln2 = nn.LayerNorm(4) + self.classifier = nn.Linear(4, 4) + + def forward(self, x): + x = self.embed(x) + x = self.proj1(x) + x = self.ln1(x) + x = self.proj2(x) + x = self.ln2(x) + x = self.classifier(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.randint(low=0, high=20, size=(16,), device=get_current_device()) + label = torch.randint(low=0, high=2, size=(16,), device=get_current_device()) + return data, label + + +@non_distributed_component_funcs.register(name='simple_net') +def get_training_components(): + + def model_builder(checkpoint=False): + return SimpleNet(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/utils/__init__.py b/tests/components_to_test/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..150124b58800f4b8ae6694c6edc7464665fd61f6 --- /dev/null +++ b/tests/components_to_test/utils/__init__.py @@ -0,0 +1,2 @@ +from .dummy_data_generator import DummyDataGenerator +from .executor import run_fwd, run_fwd_bwd diff --git a/tests/components_to_test/utils/dummy_data_generator.py b/tests/components_to_test/utils/dummy_data_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..5ab33e86de230da778d0e2bc1b8d1e8e581c1f79 --- /dev/null +++ b/tests/components_to_test/utils/dummy_data_generator.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod + + +class DummyDataGenerator(ABC): + + def __init__(self, length=10): + self.length = length + + @abstractmethod + def generate(self): + pass + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.generate() + else: + raise StopIteration + + def __len__(self): + return self.length diff --git a/tests/components_to_test/utils/executor.py b/tests/components_to_test/utils/executor.py new file mode 100644 index 0000000000000000000000000000000000000000..631401e022e661678d977d0a97842d62c11dc2ee --- /dev/null +++ b/tests/components_to_test/utils/executor.py @@ -0,0 +1,46 @@ +import torch + + +def run_fwd(model, data, label, criterion) -> torch.Tensor: + """run_fwd + run fwd for the model + + Args: + model (torch.nn.Module): a PyTorch model + data (torch.Tensor): input data + label (torch.Tensor): label + criterion (Optional[Callable]): a function of criterion + + Returns: + torch.Tensor: loss of fwd + """ + if criterion: + y = model(data) + y = y.float() + loss = criterion(y, label) + else: + loss = model(data, label) + + loss = loss.float() + return loss + + +def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor: + """run_fwd_bwd + run fwd and bwd for the model + + Args: + model (torch.nn.Module): a PyTorch model + data (torch.Tensor): input data + label (torch.Tensor): label + criterion (Optional[Callable]): a function of criterion + + Returns: + torch.Tensor: loss of fwd + """ + loss = run_fwd(model, data, label, criterion) + if optimizer: + optimizer.backward(loss) + else: + loss.backward() + return loss diff --git a/tests/kit/__init__.py b/tests/kit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..466a2a55882955cf855ab8c9a07991f9a6e833e4 --- /dev/null +++ b/tests/kit/model_zoo/__init__.py @@ -0,0 +1,4 @@ +from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers +from .registry import model_zoo + +__all__ = ['model_zoo'] diff --git a/tests/kit/model_zoo/diffusers/__init__.py b/tests/kit/model_zoo/diffusers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..288f626a4539aa4387a14ccd99fa88935654ee8d --- /dev/null +++ b/tests/kit/model_zoo/diffusers/__init__.py @@ -0,0 +1 @@ +from .diffusers import * diff --git a/tests/kit/model_zoo/diffusers/diffusers.py b/tests/kit/model_zoo/diffusers/diffusers.py new file mode 100644 index 0000000000000000000000000000000000000000..204c1d7773ca9e87bdbbb49d5cf0bb72b572315f --- /dev/null +++ b/tests/kit/model_zoo/diffusers/diffusers.py @@ -0,0 +1,74 @@ +from functools import partial + +import diffusers +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +BATCH_SIZE = 2 +SEQ_LENGTH = 5 +HEIGHT = 224 +WIDTH = 224 +IN_CHANNELS = 3 +LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7) +TIME_STEP = 3 + +data_vae_fn = lambda: dict(sample=torch.randn(2, 3, 32, 32)) +data_unet_fn = lambda: dict(sample=torch.randn(2, 3, 32, 32), timestep=3) + +identity_output = lambda x: x +clip_vision_model_output = lambda x: dict(pooler_output=x[1]) + + +def data_clip_model(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32) + return dict(input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids) + + +def data_clip_text(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_clip_vision(): + pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32) + return dict(pixel_values=pixel_values) + + +model_zoo.register(name='diffusers_auto_encoder_kl', + model_fn=diffusers.AutoencoderKL, + data_gen_fn=data_vae_fn, + output_transform_fn=identity_output) + +model_zoo.register(name='diffusers_vq_model', + model_fn=diffusers.VQModel, + data_gen_fn=data_vae_fn, + output_transform_fn=identity_output) + +model_zoo.register(name='diffusers_clip_model', + model_fn=partial(transformers.CLIPModel, config=transformers.CLIPConfig()), + data_gen_fn=data_clip_model, + output_transform_fn=identity_output) + +model_zoo.register(name='diffusers_clip_text_model', + model_fn=partial(transformers.CLIPTextModel, config=transformers.CLIPTextConfig()), + data_gen_fn=data_clip_text, + output_transform_fn=identity_output) + +model_zoo.register(name='diffusers_clip_vision_model', + model_fn=partial(transformers.CLIPVisionModel, config=transformers.CLIPVisionConfig()), + data_gen_fn=data_clip_vision, + output_transform_fn=clip_vision_model_output) + +model_zoo.register(name='diffusers_unet2d_model', + model_fn=diffusers.UNet2DModel, + data_gen_fn=data_unet_fn, + output_transform_fn=identity_output) diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..7470327a65b62e4d126e5724a43ad2748a02962d --- /dev/null +++ b/tests/kit/model_zoo/registry.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python +from dataclasses import dataclass +from typing import Callable + +__all__ = ['ModelZooRegistry', 'ModelAttributem', 'model_zoo'] + + +@dataclass +class ModelAttribute: + """ + Attributes of a model. + + Args: + has_control_flow (bool): Whether the model contains branching in its forward method. + has_stochastic_depth_prob (bool): Whether the model contains stochastic depth probability. Often seen in the torchvision models. + """ + has_control_flow: bool = False + has_stochastic_depth_prob: bool = False + + +class ModelZooRegistry(dict): + """ + A registry to map model names to model and data generation functions. + """ + + def register(self, + name: str, + model_fn: Callable, + data_gen_fn: Callable, + output_transform_fn: Callable, + model_attribute: ModelAttribute = None): + """ + Register a model and data generation function. + + Examples: + >>> # Register + >>> model_zoo = ModelZooRegistry() + >>> model_zoo.register('resnet18', resnet18, resnet18_data_gen) + >>> # Run the model + >>> data = resnresnet18_data_gen() # do not input any argument + >>> model = resnet18() # do not input any argument + >>> out = model(**data) + + Args: + name (str): Name of the model. + model_fn (callable): A function that returns a model. **It must not contain any arguments.** + output_transform_fn (callable): A function that transforms the output of the model into Dict. + data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.** + model_attribute (ModelAttribute): Attributes of the model. Defaults to None. + """ + self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute) + + def get_sub_registry(self, keyword: str): + """ + Get a sub registry with models that contain the keyword. + + Args: + keyword (str): Keyword to filter models. + """ + new_dict = dict() + + for k, v in self.items(): + if keyword in k: + new_dict[k] = v + return new_dict + + +model_zoo = ModelZooRegistry() diff --git a/tests/kit/model_zoo/timm/__init__.py b/tests/kit/model_zoo/timm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c85319448dcce9102df9c069db8755b12cf7e1 --- /dev/null +++ b/tests/kit/model_zoo/timm/__init__.py @@ -0,0 +1 @@ +from .timm import * diff --git a/tests/kit/model_zoo/timm/timm.py b/tests/kit/model_zoo/timm/timm.py new file mode 100644 index 0000000000000000000000000000000000000000..b29ac12a6b534f554f6917a78dc38a0b4f533abc --- /dev/null +++ b/tests/kit/model_zoo/timm/timm.py @@ -0,0 +1,159 @@ +import timm.models as tm +import torch + +from ..registry import ModelAttribute, model_zoo + +## ============== +# Register models without control flow +## ============== +data_gen_fn = lambda: dict(x=torch.rand(2, 3, 224, 224)) +output_transform_fn = lambda x: dict(output=x) + +model_zoo.register(name='timm_resnet', + model_fn=tm.resnest.resnest50d, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_beit', + model_fn=tm.beit.beit_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_cait', + model_fn=tm.cait.cait_s24_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_convmixer', + model_fn=tm.convmixer.convmixer_768_32, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_efficientnetv2', + model_fn=tm.efficientnet.efficientnetv2_m, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_resmlp', + model_fn=tm.resmlp_12_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_vision_transformer', + model_fn=tm.vision_transformer.vit_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_deit', + model_fn=tm.deit_base_distilled_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_beitv2', + model_fn=tm.beitv2_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_coat', + model_fn=tm.coat.coat_lite_mini, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='timm_deit3', + model_fn=tm.deit3_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='timm_eca_nfnet', + model_fn=tm.eca_nfnet_l0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_efficientformer', + model_fn=tm.efficientformer_l1, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_ese_vovnet19b_dw', + model_fn=tm.ese_vovnet19b_dw, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_gmixer_12_224', + model_fn=tm.gmixer_12_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_gmlp_b16_224', + model_fn=tm.gmlp_b16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_hardcorenas_a', + model_fn=tm.hardcorenas_a, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_hrnet_w18_small', + model_fn=tm.hrnet_w18_small, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_inception_v3', + model_fn=tm.inception_v3, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_mixer_b16_224', + model_fn=tm.mixer_b16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_nf_ecaresnet101', + model_fn=tm.nf_ecaresnet101, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_nf_regnet_b0', + model_fn=tm.nf_regnet_b0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_regnetv_040', + model_fn=tm.regnetv_040, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_skresnet18', + model_fn=tm.skresnet18, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_tnt_b_patch16_224', + model_fn=tm.tnt_b_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_wide_resnet50_2', + model_fn=tm.wide_resnet50_2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_convit', + model_fn=tm.convit_base, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_dm_nfnet', + model_fn=tm.dm_nfnet_f0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + +# ============== +# Register models with control flow +# ============== +model_zoo.register(name='timm_convnext', + model_fn=tm.convnext.convnext_base, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='timm_vgg', + model_fn=tm.vgg.vgg11, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='timm_dpn', + model_fn=tm.dpn.dpn68, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='timm_densenet', + model_fn=tm.densenet.densenet121, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='timm_rexnet', + model_fn=tm.rexnet.rexnet_100, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='timm_swin_transformer', + model_fn=tm.swin_transformer.swin_base_patch4_window7_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/torchaudio/__init__.py b/tests/kit/model_zoo/torchaudio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..082eb9ebb89c552d54351fb67abe0514ff01ac00 --- /dev/null +++ b/tests/kit/model_zoo/torchaudio/__init__.py @@ -0,0 +1 @@ +from .torchaudio import * diff --git a/tests/kit/model_zoo/torchaudio/torchaudio.py b/tests/kit/model_zoo/torchaudio/torchaudio.py new file mode 100644 index 0000000000000000000000000000000000000000..9a244ac312c0bfa0479844c887dfdee6bf4f942a --- /dev/null +++ b/tests/kit/model_zoo/torchaudio/torchaudio.py @@ -0,0 +1,130 @@ +from functools import partial + +import torch +import torchaudio.models as tm + +from ..registry import ModelAttribute, model_zoo + +INPUT_DIM = 80 +IN_FEATURES = 16 +N_TIME = 20 +KERNEL_SIZE = 5 +HOP_LENGTH = 20 +N_CLASSES = 10 +N_FREQ = 16 +N_MELS = 80 + + +def conformer_data_gen_fn(): + lengths = torch.randint(1, 400, (4,)) + input = torch.rand(4, int(lengths.max()), INPUT_DIM) + return dict(input=input, lengths=lengths) + + +transformer_output_transform_fn = lambda outputs: dict(frames=outputs[0], lengths=outputs[1]) + +model_zoo.register(name='torchaudio_conformer', + model_fn=lambda: tm.Conformer( + input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31), + data_gen_fn=conformer_data_gen_fn, + output_transform_fn=transformer_output_transform_fn) + +single_output_transform_fn = lambda output: dict(output=output) + +model_zoo.register(name='torchaudio_convtasnet', + model_fn=tm.ConvTasNet, + data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)), + output_transform_fn=single_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) + +model_zoo.register(name='torchaudio_deepspeech', + model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4), + data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)), + output_transform_fn=single_output_transform_fn) + + +def emformer_data_gen_fn(): + input = torch.rand(4, 400, IN_FEATURES) + lengths = torch.randint(1, 200, (4,)) + return dict(input=input, lengths=lengths) + + +model_zoo.register( + name='torchaudio_emformer', + model_fn=lambda: tm.Emformer(input_dim=IN_FEATURES, num_heads=4, ffn_dim=128, num_layers=4, segment_length=4), + data_gen_fn=emformer_data_gen_fn, + output_transform_fn=transformer_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) + +model_zoo.register(name='torchaudio_wav2letter_waveform', + model_fn=lambda: tm.Wav2Letter(input_type='waveform', num_features=40), + data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)), + output_transform_fn=single_output_transform_fn) + +model_zoo.register(name='torchaudio_wav2letter_mfcc', + model_fn=lambda: tm.Wav2Letter(input_type='mfcc', num_features=40), + data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)), + output_transform_fn=single_output_transform_fn) + + +def wavernn_data_gen_fn(): + waveform = torch.rand(4, 1, (N_TIME - KERNEL_SIZE + 1) * HOP_LENGTH) + specgram = torch.rand(4, 1, N_FREQ, N_TIME) + return dict(waveform=waveform, specgram=specgram) + + +model_zoo.register(name='torchaudio_wavernn', + model_fn=lambda: tm.WaveRNN(upsample_scales=[2, 2, 5], + n_classes=N_CLASSES, + hop_length=HOP_LENGTH, + kernel_size=KERNEL_SIZE, + n_freq=N_FREQ, + n_res_block=2, + n_rnn=64, + n_fc=64, + n_hidden=16, + n_output=16), + data_gen_fn=wavernn_data_gen_fn, + output_transform_fn=single_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) + + +def tacotron_data_gen_fn(): + n_batch = 4 + max_text_length = 100 + max_mel_specgram_length = 300 + tokens = torch.randint(0, 148, (n_batch, max_text_length)) + token_lengths = max_text_length * torch.ones((n_batch,)) + mel_specgram = torch.rand(n_batch, N_MELS, max_mel_specgram_length) + mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,)) + return dict(tokens=tokens, + token_lengths=token_lengths, + mel_specgram=mel_specgram, + mel_specgram_lengths=mel_specgram_lengths) + + +model_zoo.register(name='torchaudio_tacotron', + model_fn=lambda: tm.Tacotron2(n_mels=N_MELS), + data_gen_fn=tacotron_data_gen_fn, + output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)), + model_attribute=ModelAttribute(has_control_flow=True)) + + +def wav2vec_data_gen_fn(): + batch_size, num_frames = 4, 400 + waveforms = torch.randn(batch_size, num_frames) + lengths = torch.randint(0, num_frames, (batch_size,)) + return dict(waveforms=waveforms, lengths=lengths) + + +model_zoo.register(name='torchaudio_wav2vec2_base', + model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0), + data_gen_fn=wav2vec_data_gen_fn, + output_transform_fn=transformer_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) + +model_zoo.register(name='torchaudio_hubert_base', + model_fn=tm.hubert_base, + data_gen_fn=wav2vec_data_gen_fn, + output_transform_fn=transformer_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43952e6998cf2ac4c19de8077e22a713054a15e3 --- /dev/null +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -0,0 +1 @@ +from .torchrec import * diff --git a/tests/kit/model_zoo/torchrec/torchrec.py b/tests/kit/model_zoo/torchrec/torchrec.py new file mode 100644 index 0000000000000000000000000000000000000000..dda563155fcac3c996867b244fd5215a2cd65e22 --- /dev/null +++ b/tests/kit/model_zoo/torchrec/torchrec.py @@ -0,0 +1,142 @@ +from collections import namedtuple +from functools import partial + +import torch +from torchrec.models import deepfm, dlrm +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + +from ..registry import ModelAttribute, model_zoo + +BATCH = 2 +SHAPE = 10 + + +def gen_kt(): + KT = KeyedTensor(keys=["f1", "f2"], length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE))) + return KT + + +# KeyedJaggedTensor +def gen_kjt(): + KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"], + values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), + offsets=torch.tensor([0, 2, 4, 6, 8])) + return KJT + + +data_gen_fn = lambda: dict(features=torch.rand((BATCH, SHAPE))) + + +def interaction_arch_data_gen_fn(): + KT = gen_kt() + return dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KT) + + +def simple_dfm_data_gen_fn(): + KJT = gen_kjt() + return dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KJT) + + +def sparse_arch_data_gen_fn(): + KJT = gen_kjt() + return dict(features=KJT) + + +def output_transform_fn(x): + if isinstance(x, KeyedTensor): + output = dict() + for key in x.keys(): + output[key] = x[key] + return output + else: + return dict(output=x) + + +def output_transform_fn(x): + if isinstance(x, KeyedTensor): + output = dict() + for key in x.keys(): + output[key] = x[key] + return output + else: + return dict(output=x) + + +def get_ebc(): + # EmbeddingBagCollection + eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"]) + eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"]) + return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device('cpu')) + + +def sparse_arch_model_fn(): + ebc = get_ebc() + return deepfm.SparseArch(ebc) + + +def simple_deep_fmnn_model_fn(): + ebc = get_ebc() + return deepfm.SimpleDeepFMNN(SHAPE, ebc, SHAPE, SHAPE) + + +def dlrm_model_fn(): + ebc = get_ebc() + return dlrm.DLRM(ebc, SHAPE, [SHAPE, SHAPE], [5, 1]) + + +def dlrm_sparsearch_model_fn(): + ebc = get_ebc() + return dlrm.SparseArch(ebc) + + +model_zoo.register(name='deepfm_densearch', + model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='deepfm_interactionarch', + model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE), + data_gen_fn=interaction_arch_data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='deepfm_overarch', + model_fn=partial(deepfm.OverArch, SHAPE), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='deepfm_simpledeepfmnn', + model_fn=simple_deep_fmnn_model_fn, + data_gen_fn=simple_dfm_data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='deepfm_sparsearch', + model_fn=sparse_arch_model_fn, + data_gen_fn=sparse_arch_data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='dlrm', + model_fn=dlrm_model_fn, + data_gen_fn=simple_dfm_data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='dlrm_densearch', + model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='dlrm_interactionarch', + model_fn=partial(dlrm.InteractionArch, 2), + data_gen_fn=interaction_arch_data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='dlrm_overarch', + model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='dlrm_sparsearch', + model_fn=dlrm_sparsearch_model_fn, + data_gen_fn=sparse_arch_data_gen_fn, + output_transform_fn=output_transform_fn) diff --git a/tests/kit/model_zoo/torchvision/__init__.py b/tests/kit/model_zoo/torchvision/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..55d58f97b5d4febfa7c46d15281a9f088bfcf210 --- /dev/null +++ b/tests/kit/model_zoo/torchvision/__init__.py @@ -0,0 +1 @@ +from .torchvision import * diff --git a/tests/kit/model_zoo/torchvision/torchvision.py b/tests/kit/model_zoo/torchvision/torchvision.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc3ec24b2ff5d486df1054ca58fcc9c6095097a --- /dev/null +++ b/tests/kit/model_zoo/torchvision/torchvision.py @@ -0,0 +1,131 @@ +from collections import namedtuple + +import torch +import torchvision +import torchvision.models as tm +from packaging import version + +from ..registry import ModelAttribute, model_zoo + +data_gen_fn = lambda: dict(x=torch.rand(4, 3, 224, 224)) +output_transform_fn = lambda x: dict(output=x) + +# special data gen fn +inception_v3_data_gen_fn = lambda: dict(x=torch.rand(4, 3, 299, 299)) + + +# special model fn +def swin_s(): + from torchvision.models.swin_transformer import Swin_T_Weights, _swin_transformer + + # adapted from torchvision.models.swin_transformer.swin_small + weights = None + weights = Swin_T_Weights.verify(weights) + progress = True + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=[7, 7], + stochastic_depth_prob=0, # it is originally 0.2, but we set it to 0 to make it deterministic + weights=weights, + progress=progress, + ) + + +# special output transform fn +google_net_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs + ) else dict(output=x) +swin_s_output_output_transform_fn = lambda x: {f'output{idx}': val + for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x) +inception_v3_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs + ) else dict(output=x) + +model_zoo.register(name='torchvision_alexnet', + model_fn=tm.alexnet, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_densenet121', + model_fn=tm.densenet121, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_efficientnet_b0', + model_fn=tm.efficientnet_b0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) +model_zoo.register(name='torchvision_googlenet', + model_fn=tm.googlenet, + data_gen_fn=data_gen_fn, + output_transform_fn=google_net_output_transform_fn) +model_zoo.register(name='torchvision_inception_v3', + model_fn=tm.inception_v3, + data_gen_fn=inception_v3_data_gen_fn, + output_transform_fn=inception_v3_output_transform_fn) +model_zoo.register(name='torchvision_mobilenet_v2', + model_fn=tm.mobilenet_v2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_mobilenet_v3_small', + model_fn=tm.mobilenet_v3_small, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_mnasnet0_5', + model_fn=tm.mnasnet0_5, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_resnet18', + model_fn=tm.resnet18, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_regnet_x_16gf', + model_fn=tm.regnet_x_16gf, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_resnext50_32x4d', + model_fn=tm.resnext50_32x4d, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_shufflenet_v2_x0_5', + model_fn=tm.shufflenet_v2_x0_5, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_squeezenet1_0', + model_fn=tm.squeezenet1_0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='torchvision_vgg11', + model_fn=tm.vgg11, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='torchvision_wide_resnet50_2', + model_fn=tm.wide_resnet50_2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + +if version.parse(torchvision.__version__) >= version.parse('0.12.0'): + model_zoo.register(name='torchvision_vit_b_16', + model_fn=tm.vit_b_16, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + model_zoo.register(name='torchvision_convnext_base', + model_fn=tm.convnext_base, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) + +if version.parse(torchvision.__version__) >= version.parse('0.13.0'): + model_zoo.register( + name='torchvision_swin_s', + model_fn=swin_s, + data_gen_fn=data_gen_fn, + output_transform_fn=swin_s_output_output_transform_fn, + ) + model_zoo.register(name='torchvision_efficientnet_v2_s', + model_fn=tm.efficientnet_v2_s, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f56ff7ad84eb1b1360afc0fd47d817756c6df35b --- /dev/null +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -0,0 +1,5 @@ +from .albert import * +from .bert import * +from .gpt import * +from .opt import * +from .t5 import * diff --git a/tests/kit/model_zoo/transformers/albert.py b/tests/kit/model_zoo/transformers/albert.py new file mode 100644 index 0000000000000000000000000000000000000000..e85f564e376a53669374dc21eeceb88a69ee691c --- /dev/null +++ b/tests/kit/model_zoo/transformers/albert.py @@ -0,0 +1,85 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence ALBERT +# =============================== +BATCH_SIZE = 2 +SEQ_LENGTH = 16 + + +def data_gen_fn(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + +output_transform_fn = lambda x: x + +config = transformers.AlbertConfig(embedding_size=128, + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=256) + +model_zoo.register(name='transformers_albert', + model_fn=lambda: transformers.AlbertModel(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_albert_for_pretraining', + model_fn=lambda: transformers.AlbertForPreTraining(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_albert_for_masked_lm', + model_fn=lambda: transformers.AlbertForMaskedLM(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_albert_for_sequence_classification', + model_fn=lambda: transformers.AlbertForSequenceClassification(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_albert_for_token_classification', + model_fn=lambda: transformers.AlbertForTokenClassification(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) + +# =============================== +# Register multi-sentence ALBERT +# =============================== + + +def data_gen_for_qa(): + question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") + inputs = tokenizer(question, text, return_tensors="pt") + return inputs + + +def data_gen_for_mcq(): + prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + choice0 = "It is eaten with a fork and a knife." + choice1 = "It is eaten while held in the hand." + tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") + encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) + encoding = {k: v.unsqueeze(0) for k, v in encoding.items()} + return encoding + + +model_zoo.register(name='transformers_albert_for_question_answering', + model_fn=lambda: transformers.AlbertForQuestionAnswering(config), + data_gen_fn=data_gen_for_qa, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_albert_for_multiple_choice', + model_fn=lambda: transformers.AlbertForMultipleChoice(config), + data_gen_fn=data_gen_for_mcq, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..99135704da70f96237bdcb6553ebd847ea8fcba1 --- /dev/null +++ b/tests/kit/model_zoo/transformers/bert.py @@ -0,0 +1,88 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence BERT +# =============================== +BATCH_SIZE = 2 +SEQ_LENGTH = 16 + + +def data_gen_fn(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + +output_transform_fn = lambda x: x + +config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256) + +# register the BERT variants +model_zoo.register(name='transformers_bert', + model_fn=lambda: transformers.BertModel(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bert_for_pretraining', + model_fn=lambda: transformers.BertForPreTraining(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bert_lm_head_model', + model_fn=lambda: transformers.BertLMHeadModel(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bert_for_masked_lm', + model_fn=lambda: transformers.BertForMaskedLM(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bert_for_sequence_classification', + model_fn=lambda: transformers.BertForSequenceClassification(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bert_for_token_classification', + model_fn=lambda: transformers.BertForTokenClassification(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) + + +# =============================== +# Register multi-sentence BERT +# =============================== +def data_gen_for_next_sentence(): + tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") + prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + next_sentence = "The sky is blue due to the shorter wavelength of blue light." + encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + return encoding + + +def data_gen_for_mcq(): + tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") + prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + choice0 = "It is eaten with a fork and a knife." + choice1 = "It is eaten while held in the hand." + encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) + encoding = {k: v.unsqueeze(0) for k, v in encoding.items()} + return encoding + + +# register the following models +model_zoo.register(name='transformers_bert_for_next_sentence', + model_fn=lambda: transformers.BertForNextSentencePrediction(config), + data_gen_fn=data_gen_for_next_sentence, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bert_for_mcq', + model_fn=lambda: transformers.BertForMultipleChoice(config), + data_gen_fn=data_gen_for_mcq, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..5ed4fbe70dc9df559ad8285ddae49616c43c2f9e --- /dev/null +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -0,0 +1,57 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence GPT +# =============================== +BATCH_SIZE = 1 # it can only be 1 as GPT cannot handle batch sizes > 1 if no padding token is defined. +SEQ_LENGTH = 16 + + +def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + +def seq_classification_data_gen(): + # batch sizes should be 1 if no padding token is defined. + input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) + token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) + return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + +output_transform_fn = lambda x: x + +config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4) + +# register the following models +model_zoo.register(name='transformers_gpt', + model_fn=lambda: transformers.GPT2Model(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_gpt_lm', + model_fn=lambda: transformers.GPT2LMHeadModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_gpt_double_heads', + model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_gpt_for_token_classification', + model_fn=lambda: transformers.GPT2ForTokenClassification(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_gpt_for_sequence_classification', + model_fn=lambda: transformers.GPT2ForSequenceClassification(config), + data_gen_fn=seq_classification_data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c4a0b3c23c52ac54fcdff786cae7eacdb2137c --- /dev/null +++ b/tests/kit/model_zoo/transformers/opt.py @@ -0,0 +1,35 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence OPT +# =============================== +BATCH_SIZE = 2 +SEQ_LENGTH = 16 + + +def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +output_transform_fn = lambda x: x + +config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4) + +# register the following models +# transformers.OPTModel, +# transformers.OPTForCausalLM, +model_zoo.register(name='transformers_opt', + model_fn=lambda: transformers.OPTModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_opt_for_causal_lm', + model_fn=lambda: transformers.OPTForCausalLM(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..b81bcad90db87f1b3edebe48316d9d87bbd3b795 --- /dev/null +++ b/tests/kit/model_zoo/transformers/t5.py @@ -0,0 +1,46 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence T5 +# =============================== +BATCH_SIZE = 2 +SEQ_LENGTH = 16 + + +def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + + +def data_gen_for_encoder_only(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + return dict(input_ids=input_ids) + + +output_transform_fn = lambda x: x + +config = transformers.T5Config(d_model=128, num_layers=2) + +# register the following models +# transformers.T5Model, +# transformers.T5ForConditionalGeneration, +# transformers.T5EncoderModel, +model_zoo.register(name='transformers_t5', + model_fn=lambda: transformers.T5Model(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_t5_for_conditional_generation', + model_fn=lambda: transformers.T5ForConditionalGeneration(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_t5_encoder_model', + model_fn=lambda: transformers.T5EncoderModel(config), + data_gen_fn=data_gen_for_encoder_only, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_amp/test_naive_fp16.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce4c7f497254cc45a62c640fd2463159b97018f --- /dev/null +++ b/tests/test_amp/test_naive_fp16.py @@ -0,0 +1,93 @@ +import copy + +import pytest +import torch + +import colossalai +from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp +from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.components_to_test.registry import non_distributed_component_funcs + + +def check_equal(a, b): + """ + This function checks if two tensors are equal within tolerance + """ + assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}' + + +def run_naive_amp(): + """ + In this test, we compare the naive fp16 optimizer implemented in colossalai + and fp32 torch optimizer + """ + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + # create layer + test_models = ['repeated_computed_layers', 'nested_model', 'resnet18'] + for test_name in test_models: + get_component_func = non_distributed_component_funcs.get_callable(test_name) + model_builder, train_dataloader, _, optim_class, _ = get_component_func() + + # create model + naive_amp_model = model_builder(checkpoint=True).cuda() + apex_amp_model = copy.deepcopy(naive_amp_model) + + # create optimizer + # we use SGD here, since the correctness of gradient clipping can't be tested with Adam + naive_amp_optimizer = torch.optim.SGD(naive_amp_model.parameters(), lr=1e-3) + apex_amp_optimizer = torch.optim.SGD(apex_amp_model.parameters(), lr=1e-3) + + # inject naive and apex amp + naive_amp_config = dict(initial_scale=128, clip_grad_norm=1.0) + naive_amp_model, naive_amp_optimizer = convert_to_naive_amp(naive_amp_model, naive_amp_optimizer, + naive_amp_config) + apex_amp_config = dict(opt_level='O2', loss_scale=128, keep_batchnorm_fp32=False) + apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) + + # create data + data_iter = iter(train_dataloader) + data, label = next(data_iter) + data = data.cuda() + + # forward pass + naive_amp_output = naive_amp_model(data) + apex_amp_output = apex_amp_model(data) + assert_close_loose(naive_amp_output, apex_amp_output) + + # backward + # use sum() to get big gradient + naive_amp_optimizer.backward(naive_amp_output.sum()) + apex_amp_optimizer.backward(apex_amp_output.sum()) + + # check grad + for naive_amp_param, apex_amp_param in zip(naive_amp_model.parameters(), apex_amp_model.parameters()): + assert_close_loose(naive_amp_param.grad, apex_amp_param.grad) + + # clip gradient + apex_amp_optimizer.clip_grad_norm(model=apex_amp_model, max_norm=1.0) + + # step + naive_amp_optimizer.step() + apex_amp_optimizer.step() + + # check updated param + for naive_amp_param, apex_amp_param in zip(naive_amp_model.parameters(), apex_amp_model.parameters()): + assert_close_loose(naive_amp_param, apex_amp_param) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + run_naive_amp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_naive_amp(): + spawn(run_dist, 1) + + +if __name__ == '__main__': + test_naive_amp() diff --git a/tests/test_amp/test_torch_fp16.py b/tests/test_amp/test_torch_fp16.py new file mode 100644 index 0000000000000000000000000000000000000000..6451aa6264a37c363e5458d1a0b859d67bddc122 --- /dev/null +++ b/tests/test_amp/test_torch_fp16.py @@ -0,0 +1,93 @@ +import copy + +import pytest +import torch + +import colossalai +from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp +from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.components_to_test.registry import non_distributed_component_funcs + + +def run_torch_amp(): + """ + In this test, we compare the torch amp and apex amp implemented in colossalai + """ + + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + # create layer + test_models = ['resnet18', 'simple_net'] + for test_name in test_models: + get_component_func = non_distributed_component_funcs.get_callable(test_name) + model_builder, train_dataloader, _, optim_class, _ = get_component_func() + + # create model + torch_amp_model = model_builder(checkpoint=True).cuda() + apex_amp_model = copy.deepcopy(torch_amp_model) + + # create optimizer + # we use SGD here, since the correctness of gradient clipping can't be tested with Adam + torch_amp_optimizer = torch.optim.SGD(torch_amp_model.parameters(), lr=1e-3) + apex_amp_optimizer = torch.optim.SGD(apex_amp_model.parameters(), lr=1e-3) + + # inject torch and apex amp + torch_amp_config = dict(init_scale=128, enabled=True) + torch_amp_model, torch_amp_optimizer, _ = convert_to_torch_amp(torch_amp_model, + torch_amp_optimizer, + amp_config=torch_amp_config) + apex_amp_config = dict(opt_level='O1', loss_scale=128) + apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) + + # create data + data_iter = iter(train_dataloader) + data, label = next(data_iter) + data = data.cuda() + + # forward pass + torch_amp_output = torch_amp_model(data) + apex_amp_output = apex_amp_model(data) + assert_close_loose(torch_amp_output, apex_amp_output) + + for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()): + assert_close_loose(torch_amp_param, apex_amp_param) + + # backward + # use sum() to get big gradient + torch_amp_optimizer.backward(torch_amp_output.sum()) + apex_amp_optimizer.backward(apex_amp_output.sum()) + + # check grad + # In apex amp, grad is not scaled before backward, but torch amp does + for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()): + assert_close_loose(torch_amp_param.grad, apex_amp_param.grad * apex_amp_config['loss_scale']) + + # clip gradient + apex_amp_optimizer.clip_grad_norm(model=apex_amp_model, max_norm=1.0) + torch_amp_optimizer.clip_grad_norm(model=torch_amp_model, max_norm=1.0) + + # step + torch_amp_optimizer.step() + apex_amp_optimizer.step() + + # check updated param and grad + for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()): + assert_close_loose(torch_amp_param.grad, apex_amp_param.grad) + assert_close_loose(torch_amp_param, apex_amp_param) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + run_torch_amp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_torch_amp(): + spawn(run_dist, 1) + + +if __name__ == '__main__': + test_torch_amp() diff --git a/tests/test_analyzer/__init__.py b/tests/test_analyzer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_analyzer/test_fx/__init__.py b/tests/test_analyzer/test_fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_analyzer/test_fx/test_bias_addition.py b/tests/test_analyzer/test_fx/test_bias_addition.py new file mode 100644 index 0000000000000000000000000000000000000000..f7b5eb140f2437c03bbf3f6f399960f5c5d6ed98 --- /dev/null +++ b/tests/test_analyzer/test_fx/test_bias_addition.py @@ -0,0 +1,122 @@ +import pytest +import torch +from packaging import version +from torch.utils.checkpoint import checkpoint + +from colossalai.testing.utils import clear_cache_before_run, parameterize + +try: + from colossalai._analyzer.fx import symbolic_trace +except: + pass + + +class LinearModel(torch.nn.Module): + + def __init__(self, in_features, out_features, bias): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x): + x = self.linear(x) + return x + + +class ConvModel(torch.nn.Module): + + def __init__(self, in_channel, out_channels, kernel_size, bias) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(in_channel, + out_channels, + kernel_size, + bias=bias, + padding=1, + stride=2, + dilation=2, + groups=3) + self.conv_transpose = torch.nn.ConvTranspose2d(in_channel, + out_channels, + kernel_size, + bias=bias, + padding=1, + stride=2, + dilation=2, + groups=3) + + def forward(self, x, select=0): + if select == 0: + x = self.conv(x) + else: + x = self.conv_transpose(x) + return x + + +class SiuModel(torch.nn.Module): + + def __init__(self, bias) -> None: + super().__init__() + self.linear = LinearModel(3, 3, bias) + self.conv = ConvModel(3, 6, 3, bias) + + def forward(self, x, select=torch.Tensor([0])): + x = self.linear(x) + if select: + x = checkpoint(self.conv, x, 0) + else: + x = checkpoint(self.conv, x, 1) + + return x + + +class AddmmModel(torch.nn.Module): + + def __init__(self, alpha, beta) -> None: + super().__init__() + self.alpha = alpha + self.beta = beta + + def forward(self, x): + x = torch.addmm(x, x, x, alpha=self.alpha, beta=self.beta) + return x + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() +@parameterize("bias", [True, False]) +@parameterize("bias_addition_split", [True, False]) +@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)]) +@parameterize("select", [torch.Tensor([0]), torch.Tensor([1])]) +def test_siu_model(bias, bias_addition_split, shape, select): + model = SiuModel(bias=bias) + x = torch.rand(shape) + gm = symbolic_trace(model, + meta_args={'x': x}, + concrete_args={'select': select}, + trace_act_ckpt=True, + bias_addition_split=bias_addition_split) + assert torch.allclose(model(x, select), gm(x)), 'original model and traced model should be the same!' + if bias and bias_addition_split: + assert '+' in gm.code, 'bias addition should be split!' + else: + assert '+' not in gm.code, 'bias addition should not be split!' + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@parameterize("alpha", [1, 2]) +@parameterize("beta", [1, 2]) +@parameterize("bias_addition_split", [True, False]) +@parameterize("shape", [(3, 3), (5, 5)]) +def test_addmm_model(alpha, beta, bias_addition_split, shape): + model = AddmmModel(alpha=alpha, beta=beta) + x = torch.rand(shape) + gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split) + assert torch.allclose(model(x), gm(x)), 'original model and traced model should be the same!' + if (alpha == 1 and beta == 1) or not bias_addition_split: + assert '*' not in gm.code, 'bias addition should not be split!' + elif bias_addition_split: + assert '+' in gm.code, 'bias addition should be split!' + + +if __name__ == '__main__': + test_siu_model() + test_addmm_model() diff --git a/tests/test_analyzer/test_fx/test_mod_dir.py b/tests/test_analyzer/test_fx/test_mod_dir.py new file mode 100644 index 0000000000000000000000000000000000000000..f62147b297a2ccd6ede7cbbf56804125172fa847 --- /dev/null +++ b/tests/test_analyzer/test_fx/test_mod_dir.py @@ -0,0 +1,81 @@ +import pytest +import torch + +from colossalai.testing import clear_cache_before_run, parameterize + +try: + from colossalai._analyzer.fx import symbolic_trace +except: + pass + + +class LinearModel(torch.nn.Module): + + def __init__(self, in_features, out_features, bias): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x): + x = self.linear(x) + return x + + +class ConvModel(torch.nn.Module): + + def __init__(self, in_channel, out_channels, kernel_size, bias) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(in_channel, + out_channels, + kernel_size, + bias=bias, + padding=1, + stride=2, + dilation=2, + groups=3) + self.conv_transpose = torch.nn.ConvTranspose2d(out_channels, + out_channels, + kernel_size, + bias=bias, + padding=1, + stride=2, + dilation=2, + groups=3) + + def forward(self, x): + x = self.conv(x) + x = self.conv_transpose(x) + return x + + +class AModel(torch.nn.Module): + + def __init__(self, bias) -> None: + super().__init__() + self.linear_1 = LinearModel(3, 3, bias) + self.linear_2 = LinearModel(3, 3, bias) + self.conv = ConvModel(3, 6, 3, bias) + + def forward(self, x): + for i in range(x.shape[0]): + x = self.linear_1(x) + x = self.linear_2(x) + x = self.conv(x) + return x + + +@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@clear_cache_before_run() +@parameterize("bias", [True, False]) +@parameterize("bias_addition_split", [True, False]) +@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)]) +def test_mod_dir(bias, bias_addition_split, shape): + model = AModel(bias=bias) + x = torch.rand(shape) + gm = symbolic_trace(model, meta_args={'x': x}, bias_addition_split=bias_addition_split) + for node in gm.graph.nodes: + assert len(node.meta['info'].mod_dir), f"{node} should have non-trivial ``mod_dir``." + print(node, node.meta['info'].mod_dir) + + +if __name__ == '__main__': + test_mod_dir(bias=True, bias_addition_split=True, shape=(3, 3, 3)) diff --git a/tests/test_analyzer/test_fx/test_nested_ckpt.py b/tests/test_analyzer/test_fx/test_nested_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..bd16f5a4f95dd2ad4b60221a3b0281ce6b5a766d --- /dev/null +++ b/tests/test_analyzer/test_fx/test_nested_ckpt.py @@ -0,0 +1,58 @@ +import pytest +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from colossalai.testing import clear_cache_before_run + +try: + from colossalai._analyzer.fx import symbolic_trace +except: + pass + + +class MyModule(nn.Module): + + def __init__(self): + super().__init__() + self.a = nn.Linear(10, 10) + self.b = nn.Linear(10, 10) + self.c = nn.Linear(10, 10) + self.d = nn.Linear(10, 10) + self.e = nn.Linear(10, 10) + + def checkpoint_0(self, x): + return checkpoint(self.checkpoint_0_0, x) + checkpoint(self.checkpoint_0_1, x) + self.e(x) + + def checkpoint_0_0(self, x): + return checkpoint(self.checkpoint_0_0_0, x) + checkpoint(self.checkpoint_0_0_1, x) + + def checkpoint_0_0_0(self, x): + return self.a(x) + checkpoint(self.checkpoint_0_0_0_0, x, use_reentrant=False) + + def checkpoint_0_0_0_0(self, x): + return self.b(x) + + def checkpoint_0_0_1(self, x): + return self.b(x) + self.c(x) + + def checkpoint_0_1(self, x): + return self.d(x) + + def forward(self, x): + return checkpoint(self.checkpoint_0, x) + + +@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@clear_cache_before_run() +def test_nested_ckpt(): + model = MyModule() + x = torch.rand(10, 10) + gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True) + assert torch.allclose(gm(x), model(x)), "The traced model should generate the same output as the original model." + for ckpt_def in filter(lambda s: s.startswith('checkpoint'), dir(model)): + assert ckpt_def in gm.code, f"Checkpoint {ckpt_def} should be in the traced code.\n Traced code = {gm.code}" + + +if __name__ == "__main__": + test_nested_ckpt() diff --git a/tests/test_analyzer/test_fx/test_shape_prop.py b/tests/test_analyzer/test_fx/test_shape_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..a849feb795e5d4a935892d88ddfc099ec2a6c1b6 --- /dev/null +++ b/tests/test_analyzer/test_fx/test_shape_prop.py @@ -0,0 +1,67 @@ +import pytest +import torch +import torchvision.models as tm +from packaging import version + +from colossalai.testing.utils import clear_cache_before_run, parameterize +from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models + +try: + from colossalai._analyzer._subclasses import MetaTensorMode + from colossalai._analyzer.fx import symbolic_trace + from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass + from colossalai._analyzer.fx.symbolic_profile import register_shape_impl + + @register_shape_impl(torch.nn.functional.linear) + def linear_impl(*args, **kwargs): + assert True + return torch.nn.functional.linear(*args, **kwargs) +except: + pass + + +def _check_gm_validity(gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + assert node.meta['info'].outputs, f'In {gm.__class__.__name__}, {node} has no output shape.' + if node.op in [ + 'call_module', # can apply to params + 'call_function', # can apply to params + 'call_method', # can apply to params + ]: + assert hasattr(node.meta['info'], 'inputs'), f'In {gm.__class__.__name__}, {node} has no input shape.' + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() +@parameterize('m', tm_models) +def test_torchvision_shape_prop(m): + with MetaTensorMode(): + model = m() + data = torch.rand(100, 3, 224, 224) + meta_args = { + "x": data, + } + gm = symbolic_trace(model, meta_args=meta_args) + shape_prop_pass(gm, data) + _check_gm_validity(gm) + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() +@parameterize('m', tmm_models) +def test_timm_shape_prop(m): + with MetaTensorMode(): + model = m() + data = torch.rand(100, 3, 224, 224) + meta_args = { + "x": data, + } + + gm = symbolic_trace(model, meta_args=meta_args) + shape_prop_pass(gm, data) + _check_gm_validity(gm) + + +if __name__ == "__main__": + test_torchvision_shape_prop() + test_timm_shape_prop() diff --git a/tests/test_analyzer/test_fx/test_symbolic_profile.py b/tests/test_analyzer/test_fx/test_symbolic_profile.py new file mode 100644 index 0000000000000000000000000000000000000000..17deee7a71188ba60def6e79e9995727c3589001 --- /dev/null +++ b/tests/test_analyzer/test_fx/test_symbolic_profile.py @@ -0,0 +1,53 @@ +import pytest +import torch +import torchvision.models as tm +from packaging import version + +from colossalai.testing.utils import clear_cache_before_run, parameterize +from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models + +try: + from colossalai._analyzer._subclasses import MetaTensorMode + from colossalai._analyzer.fx import symbolic_profile, symbolic_trace +except: + pass + + +def _check_gm_validity(gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + assert len(node.meta['info'].global_ctx), f'In {gm.__class__.__name__}, {node} has empty global context.' + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() +@parameterize('m', tm_models) +def test_torchvision_profile(m, verbose=False, bias_addition_split=False): + with MetaTensorMode(): + model = m() + data = torch.rand(8, 3, 224, 224) + meta_args = { + "x": data, + } + gm = symbolic_trace(model, meta_args=meta_args, bias_addition_split=bias_addition_split) + symbolic_profile(gm, data, verbose=verbose) + _check_gm_validity(gm) + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() +@parameterize('m', tmm_models) +def test_timm_profile(m, verbose=False, bias_addition_split=False): + with MetaTensorMode(): + model = m() + data = torch.rand(8, 3, 224, 224) + meta_args = { + "x": data, + } + gm = symbolic_trace(model, meta_args=meta_args, bias_addition_split=bias_addition_split) + symbolic_profile(gm, data, verbose=verbose) + _check_gm_validity(gm) + + +if __name__ == "__main__": + test_torchvision_profile() + test_timm_profile() diff --git a/tests/test_analyzer/test_fx/zoo.py b/tests/test_analyzer/test_fx/zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..a96aa3949134e6ca3555a7d69390b1140bfdb5e4 --- /dev/null +++ b/tests/test_analyzer/test_fx/zoo.py @@ -0,0 +1,53 @@ +import timm.models as tmm +import torchvision.models as tm + +# input shape: (batch_size, 3, 224, 224) +tm_models = [ + tm.alexnet, + tm.convnext_base, + tm.densenet121, + # tm.efficientnet_v2_s, + # tm.googlenet, # output bad case + # tm.inception_v3, # bad case + tm.mobilenet_v2, + tm.mobilenet_v3_small, + tm.mnasnet0_5, + tm.resnet18, + tm.regnet_x_16gf, + tm.resnext50_32x4d, + tm.shufflenet_v2_x0_5, + tm.squeezenet1_0, + # tm.swin_s, # fx bad case + tm.vgg11, + tm.vit_b_16, + tm.wide_resnet50_2, +] + +tmm_models = [ + tmm.beit_base_patch16_224, + tmm.beitv2_base_patch16_224, + tmm.cait_s24_224, + tmm.coat_lite_mini, + tmm.convit_base, + tmm.deit3_base_patch16_224, + tmm.dm_nfnet_f0, + tmm.eca_nfnet_l0, + tmm.efficientformer_l1, + # tmm.ese_vovnet19b_dw, + tmm.gmixer_12_224, + tmm.gmlp_b16_224, + # tmm.hardcorenas_a, + tmm.hrnet_w18_small, + tmm.inception_v3, + tmm.mixer_b16_224, + tmm.nf_ecaresnet101, + tmm.nf_regnet_b0, + # tmm.pit_b_224, # pretrained only + # tmm.regnetv_040, + # tmm.skresnet18, + # tmm.swin_base_patch4_window7_224, # fx bad case + # tmm.tnt_b_patch16_224, # bad case + tmm.vgg11, + tmm.vit_base_patch16_18x2_224, + tmm.wide_resnet50_2, +] diff --git a/tests/test_analyzer/test_subclasses/__init__.py b/tests/test_analyzer/test_subclasses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_analyzer/test_subclasses/test_aten.py b/tests/test_analyzer/test_subclasses/test_aten.py new file mode 100644 index 0000000000000000000000000000000000000000..b7858110ac0939c4d6d09825143167e18b32e58f --- /dev/null +++ b/tests/test_analyzer/test_subclasses/test_aten.py @@ -0,0 +1,85 @@ +from typing import Any, Callable, Union + +import pytest +import torch +import torch.nn as nn + +from colossalai.testing import clear_cache_before_run + +try: + from colossalai._analyzer._subclasses import MetaTensor +except: + pass + +aten = torch.ops.aten + +registered_meta = { + ('aten.convolution.default', True): [ # (aten ops, requires_backward) + (nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), + (nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)), + (nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)), + (nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), + (nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, + dilation=2), torch.rand(2, 3, 4, 4)), + (nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, + dilation=2), torch.rand(2, 3, 4, 4, 4)), + ], + ('aten.native_batch_norm.default', True): [ + (nn.BatchNorm1d(4), torch.rand(2, 4)), + (nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)), + (nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)), + ], + ('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),], + ('aten.avg_pool1d.default', True): [ + (nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)), + (nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)), + (nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)), + (nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)), + ], + ('aten.avg_pool2d.default', True): [ + (nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), + (nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), + (nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)), + (nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)), + ], + ('aten.relu.default', True): [ + (nn.ReLU(), torch.rand(4, 3, 1, 2)), + (nn.LeakyReLU(), torch.rand(4, 3, 1, 2)), + (nn.SiLU(), torch.rand(4, 3, 1, 2)), + (nn.GELU(), torch.rand(4, 3, 1, 2)), + (nn.ELU(), torch.rand(4, 3, 1, 2)), + (nn.Sigmoid(), torch.rand(4, 3, 1, 2)), + (nn.Tanh(), torch.rand(4, 3, 1, 2)), + (nn.Hardswish(), torch.rand(4, 3, 1, 2)), + ] +} + + +def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any: + assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.' + assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.' + assert tensor.stride() == meta_tensor.stride( + ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.' + + +def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any: + x.requires_grad = requires_backward + meta_x = MetaTensor(x) + x_out, meta_out = f(x), f(meta_x) + compare_all(x_out, meta_out) + if requires_backward: + x_out.sum().backward() + meta_out.sum().backward() + compare_all(x.grad, meta_x.grad) + + +@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@clear_cache_before_run() +def test_meta_aten(): + for (aten_op, requires_backward), v in registered_meta.items(): + for f, x in v: + run_and_compare(f, x, requires_backward) + + +if __name__ == '__main__': + test_meta_aten() diff --git a/tests/test_analyzer/test_subclasses/test_flop_tensor.py b/tests/test_analyzer/test_subclasses/test_flop_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..da3829e401468be885bc226cb50c67d9ebc8288e --- /dev/null +++ b/tests/test_analyzer/test_subclasses/test_flop_tensor.py @@ -0,0 +1,53 @@ +import pytest +import torch +import torch.nn.functional as F +import torchvision.models as tm +from packaging import version + +from colossalai.testing import clear_cache_before_run, parameterize +from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models + +try: + from colossalai._analyzer._subclasses import MetaTensorMode, flop_count +except: + pass + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.parametrize('m', tm_models + tmm_models) +def test_flop_count_module(m): + x = torch.rand(2, 3, 224, 224) + with MetaTensorMode(): # save time for testing + module = m() + rs_fwd, rs_bwd = flop_count(module, x, verbose=True) + assert rs_fwd > 0, f'fwd flop count of {m.__name__} is {rs_fwd}' + assert rs_bwd > 0, f'bwd flop count of {m.__name__} is {rs_bwd}' + + +odd_cases = [ + (F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), { + 'inplace': True + }), + (F.max_pool2d, (torch.rand(2, 3, 224, 224, requires_grad=True),), { + 'kernel_size': 3, + 'stride': 2, + 'padding': 1, + 'dilation': 2 + }), + (torch.where, (torch.rand(2, 3, 224, 224) > 0.5, torch.rand(2, 3, 224, 224, requires_grad=True), + torch.rand(2, 3, 224, 224, requires_grad=True)), {}), +] + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() +@parameterize('func, args, kwargs', odd_cases) +def test_flop_count_function(func, args, kwargs): + rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True) + assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}' + assert rs_bwd > 0, f'bwd flop count of {func.__name__} is {rs_bwd}' + + +if __name__ == '__main__': + test_flop_count_module(tm.resnet18) + test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True}) diff --git a/tests/test_analyzer/test_subclasses/test_meta_mode.py b/tests/test_analyzer/test_subclasses/test_meta_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..d2a0a1b9cfb590b4a32b2fafe17ede5a9c4de0ef --- /dev/null +++ b/tests/test_analyzer/test_subclasses/test_meta_mode.py @@ -0,0 +1,42 @@ +import pytest +import torch +import torchvision.models as tm +from packaging import version + +from colossalai.testing import clear_cache_before_run, parameterize + +try: + from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode +except: + pass +from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models + + +def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor): + assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.' + assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.' + assert tensor.stride() == meta_tensor.stride( + ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.' + + +def run_and_compare(model): + x = torch.rand(2, 3, 224, 224, requires_grad=True) + x_out = model(x) + with MetaTensorMode(): + meta_x = torch.rand(2, 3, 224, 224, requires_grad=True) + meta_out = model(meta_x) + compare_all(x_out, meta_out) + x_out.sum().backward() + meta_out.sum().backward() + compare_all(x.grad, meta_x.grad) + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() +@parameterize('m', tm_models + tmm_models) +def test_meta_mode_shape(m): + run_and_compare(m()) + + +if __name__ == '__main__': + test_meta_mode_shape(tm.resnet18) diff --git a/tests/test_auto_parallel/__init__.py b/tests/test_auto_parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py new file mode 100644 index 0000000000000000000000000000000000000000..f184f64b35d020196ded866936a142c6858533c4 --- /dev/null +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py @@ -0,0 +1,78 @@ +import copy + +import pytest +import torch +import torch.fx +import torchvision.models as tm + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx._compatibility import is_compatible_with_meta +# from colossalai.fx.passes.algorithms import solver_rotor +# from colossalai.fx.passes.algorithms.operation import Sequence +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.testing import rerun_if_address_is_in_use, spawn + +if is_compatible_with_meta(): + from colossalai.fx.profiler.tensor import MetaTensor + +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + withcodegen = True +except: + from colossalai.fx.codegen import python_code_with_activation_checkpoint + withcodegen = False + + +def _run_C_solver_consistency_test(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]: + model = M() + data = torch.rand(128, 3, 224, 224, device='meta') + + tracer = ColoTracer() + graph = tracer.trace(model, meta_args={"x": data}) + graph.set_codegen(ActivationCheckpointCodeGen()) + gm = ColoGraphModule(model, graph, model.__class__.__name__) + if is_compatible_with_meta(): + data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device) + MetaInfoProp(gm).run(data_meta) + + # python solver + gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024, force_python=True) + sequence_python: Sequence = copy.deepcopy(gm.__sequence__) + opt_python = copy.deepcopy(gm.__opttable__) + + # C solver + gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024) + sequence_C: Sequence = copy.deepcopy(gm.__sequence__) + opt_C = copy.deepcopy(gm.__opttable__) + + # make sure the opt_tables are the same + for m in range(len(opt_python)): + for d in range(1, len(opt_python[0])): + for i in range(len(opt_python[0]) - d): + assert opt_python[m][i][i + d] == opt_C[m][i][i + d], \ + f"item ({m}, {i}, {i + d}) is not consistent with python version!\npython version: {opt_python[m][i][i + d]}\nC version: {opt_C[m][i][i + d]}" + + sequence_python = sequence_python.list_operations() + sequence_C = sequence_C.list_operations() + + # make sure the sequences are the same + assert len(sequence_python) == len(sequence_C) and \ + all(python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C)) + + gpc.destroy() + + +@pytest.mark.skip("TODO(lyl): refactor all tests.") +@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0") +@rerun_if_address_is_in_use() +def test_C_solver_consistency(): + spawn(_run_C_solver_consistency_test, 1) + + +if __name__ == '__main__': + _run_C_solver_consistency_test(rank=0) diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py new file mode 100644 index 0000000000000000000000000000000000000000..db268b91d0a0ed1394b139ff503747f90ee81156 --- /dev/null +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py @@ -0,0 +1,142 @@ +import copy +import re +from typing import Callable + +import pytest +import torch +import torchvision.models as tm +from torch.fx import GraphModule + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.fx import ColoTracer +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.fx.graph_module import ColoGraphModule +# from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.testing import rerun_if_address_is_in_use, spawn + +if is_compatible_with_meta(): + from colossalai.fx.profiler.tensor import MetaTensor + +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True +except: + # fall back to older pytorch version + from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False + +# SOLVERS = [chen_greedy, solver_rotor] +SOLVERS = [] + + +def _is_activation_checkpoint_available(gm: GraphModule): + for n in gm.graph.nodes: + if hasattr(n, 'activation_checkpoint') and getattr(n, 'activation_checkpoint') is not None: + return True + + +def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule): + for m_p, gm_p in zip(m.parameters(), gm.parameters()): + if not torch.allclose(m_p.grad, gm_p.grad): + return False + return True + + +def _is_graph_linearized(gm: GraphModule): + code = gm.code + # find patterns like r' return output_1, output_2', which is not expected on a linearized graph + pattern = re.compile(r' return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+') + if pattern.findall(code): + return False + else: + return True + + +def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule], + model_cls: Callable[[], torch.nn.Module]): + criterion = torch.nn.MSELoss() + m.cuda() + data = torch.rand(2, 3, 32, 32).cuda() + label = torch.rand(2, 5).cuda() + loss = criterion(m(data), label) + loss.backward() + loss = criterion(gm(data), label) + loss.backward() + assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' + + +def _run_ckpt_solver(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MODEL_LIST = [tm.densenet121] + + torch.backends.cudnn.deterministic = True + + tracer = ColoTracer(trace_act_ckpt=False) + + data = torch.rand(8, 3, 224, 224, device='meta') + for solver in SOLVERS: + for model_cls in MODEL_LIST: + m = model_cls(num_classes=5) + graph = tracer.trace(root=m) + gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__) + MetaInfoProp(gm.cuda()).run(MetaTensor(data).cuda()) + codegen = ActivationCheckpointCodeGen() + gm.graph.set_codegen(codegen) + if solver == solver_rotor: + gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500) + else: + gm = solver(gm) + assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." + assert _is_activation_checkpoint_available( + gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" + check_backward_consistency(m, gm, solver, model_cls) + gpc.destroy() + + +@pytest.mark.skip("TODO(super-dainiu): refactor all tests.") +@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@rerun_if_address_is_in_use() +def test_ckpt_solver(): + spawn(_run_ckpt_solver, 1) + + +def _run_ckpt_solver_torch11(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MODEL_LIST = [tm.densenet121] + + torch.backends.cudnn.deterministic = True + + tracer = ColoTracer(trace_act_ckpt=False) + + data = torch.rand(8, 3, 32, 32, device='meta') + for solver in SOLVERS: + for model_cls in MODEL_LIST: + m = model_cls(num_classes=5) + graph = tracer.trace(root=m) + gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__) + MetaInfoProp(gm).run(data) + gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph) + if solver == solver_rotor: + gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500, force_python=True) + else: + gm = solver(gm) + assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." + assert _is_activation_checkpoint_available( + gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" + check_backward_consistency(m, gm, solver, model_cls) + gpc.destroy() + + +@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") +@rerun_if_address_is_in_use() +def test_ckpt_solver_torch11(): + spawn(_run_ckpt_solver_torch11, 1) + + +if __name__ == '__main__': + _run_ckpt_solver(rank=0) + test_ckpt_solver() + test_ckpt_solver_torch11() diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py b/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py new file mode 100644 index 0000000000000000000000000000000000000000..59880815dc5ebbd97dc7ee7fca52ca502be4207c --- /dev/null +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py @@ -0,0 +1,144 @@ +import pytest +import torch +import torchvision.models as tm + +from colossalai.fx import ColoTracer +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.fx.graph_module import ColoGraphModule +# from colossalai.fx.passes.algorithms import linearize, solver_rotor +# from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss) +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.testing import clear_cache_before_run + +if is_compatible_with_meta(): + from colossalai.fx.profiler.tensor import MetaTensor + +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True +except: + # fall back to older pytorch version + from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False + + +@pytest.mark.skip(reason='TODO: modify the logger') +@pytest.mark.skip("TODO(lyl): refactor all tests.") +@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") +@clear_cache_before_run() +def test_linearize(): + MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]} + tracer = ColoTracer() + for M, budgets in MODEL_DICT.items(): + for budget in budgets: + model = M() + graph = tracer.trace(model) + graph.set_codegen(ActivationCheckpointCodeGen()) + gm = ColoGraphModule(model, graph, model.__class__.__name__) + MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device='cpu')) + node_list = linearize(gm) + gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2) + op_list = gm.__sequence__.list_operations() + loss_op = next(op for op in op_list if isinstance(op, Loss)) + op_list = op_list[:op_list.index(loss_op)] + in_ckpt = False + ckpt_idx = 0 + for idx, op in enumerate(op_list): + if in_ckpt: + if isinstance(op, ForwardNograd): + for n in node_list[idx]: + assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" + assert n.activation_checkpoint[ + 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + + continue + + if isinstance(op, ForwardEnable): + for n in node_list[idx]: + assert getattr(n, "activation_checkpoint", None) == None, f"{n} should not be annotated!" + in_ckpt = False + + ckpt_idx += 1 + continue + + if isinstance(op, ForwardCheck): + ckpt_idx += 1 + for n in node_list[idx]: + assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" + assert n.activation_checkpoint[ + 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + + continue + + else: + if isinstance(op, ForwardCheck): + in_ckpt = True + for n in node_list[idx]: + assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" + assert n.activation_checkpoint[ + 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + + del model + del gm + del node_list + + +@pytest.mark.skip("TODO(lyl): refactor all tests.") +@pytest.mark.skip(reason="torch11 meta tensor not implemented") +@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0") +@clear_cache_before_run() +def test_linearize_torch11(): + MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]} + tracer = ColoTracer() + for M, budgets in MODEL_DICT.items(): + for budget in budgets: + model = M() + graph = tracer.trace(model) + gm = ColoGraphModule(model, graph, model.__class__.__name__) + gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph) + node_list = linearize(gm) + gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2) + op_list = gm.__sequence__.list_operations() + loss_op = next(op for op in op_list if isinstance(op, Loss)) + op_list = op_list[:op_list.index(loss_op)] + in_ckpt = False + ckpt_idx = 0 + for idx, op in enumerate(op_list): + if in_ckpt: + if isinstance(op, ForwardNograd): + for n in node_list[idx]: + assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" + assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!" + + continue + + if isinstance(op, ForwardEnable): + for n in node_list[idx]: + assert getattr(n, "activation_checkpoint", None) == None, f"{n} should not be annotated!" + in_ckpt = False + + ckpt_idx += 1 + continue + + if isinstance(op, ForwardCheck): + ckpt_idx += 1 + for n in node_list[idx]: + assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" + assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!" + + continue + + else: + if isinstance(op, ForwardCheck): + in_ckpt = True + for n in node_list[idx]: + assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" + assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!" + + del model + del gm + del node_list + + +if __name__ == "__main__": + test_linearize() diff --git a/tests/test_auto_parallel/test_offload/model_utils.py b/tests/test_auto_parallel/test_offload/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c22b17ae42ba6b5f201698048dbfd7c0bb628341 --- /dev/null +++ b/tests/test_auto_parallel/test_offload/model_utils.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +from transformers import GPT2Config, GPT2LMHeadModel +from transformers import BertConfig, BertLMHeadModel +from tests.components_to_test.registry import non_distributed_component_funcs + +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257): + super().__init__() + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size)) + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0] + + +class LMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + +class BertLMModel(nn.Module): + def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=32, vocab_size=30522): + super().__init__() + self.model = BertLMHeadModel(BertConfig(n_embd=hidden_size, num_hidden_layers=num_layers, hidden_size=hidden_size, + num_attention_heads=num_attention_heads, max_position_embeddings=hidden_size, + vocab_size=vocab_size)) + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0] + +@non_distributed_component_funcs.register(name='bert_') +def get_bert_components(): + vocab_size = 1024 + seq_len = 64 + batchSize = 64 + + def bert_model_builder(): + model = BertLMModel(hidden_size=8192, num_layers=4, num_attention_heads=32, vocab_size=vocab_size) + return model + + def bert_data_gen(device="meta"): + input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device) + attention_mask = torch.ones_like(input_ids, device=device) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + return kwargs + + return bert_model_builder, bert_data_gen + +@non_distributed_component_funcs.register(name='gpt2_') +def get_gpt2_components(): + vocab_size = 1024 + seq_len = 8 + batchSize = 64 + + def gpt2_model_builder(): + model = GPTLMModel(hidden_size=8192, num_layers=2, num_attention_heads=32, vocab_size=vocab_size) + return model + + def gpt2_data_gen(device="meta"): + input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device) + attention_mask = torch.ones_like(input_ids, device=device) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + return kwargs + + return gpt2_model_builder, gpt2_data_gen \ No newline at end of file diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py new file mode 100644 index 0000000000000000000000000000000000000000..80f134fd85d0007226ce0883b3ee302dbf87ee52 --- /dev/null +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -0,0 +1,147 @@ +import time + +import pytest +import torch +from torch.utils._pytree import tree_map + +import colossalai +from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer +from colossalai.auto_parallel.offload.mem_optimize import memory_optimize +from colossalai.auto_parallel.offload.solver import NOT_NVML +from colossalai.fx.profiler import parameter_size +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper +from tests.test_auto_parallel.test_offload.model_utils import * +from tests.test_tensor.common_utils import set_seed + + +@parameterize('model_name', ['gpt2_']) +@parameterize('memory_budget', [5000]) +@parameterize('solver_name', ['asyn']) +def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): + + # build model + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, data_gen = get_components_func() + label = torch.randint(low=0, high=128, size=( + 64, + 8, + ), device=get_current_device()) + criterion = LMLoss() + + set_seed(42) + start_time = time.time() + model = model_builder() + model.train() + param_size = parameter_size(model) / 1024**2 / 2 + init_time = time.time() - start_time + print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s") + + data_args = data_gen(device="cpu") + wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x + data_args = tree_map(wrap_fn, data_args) + start_time = time.time() + model = memory_optimize(model, data_args, memory_budget * 1024 * 1024, solver_name) + solver_time = time.time() - start_time + print(f"solver_time={solver_time:.3f} s") + + hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3) + optim = AMPOptimizer(hybrid_optimizer, model) + + with ColoInitContext(device=torch.device('cpu')): + gemini_model = model_builder() + gemini_model.train() + + hybrid_optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3) + gemini_config = dict(strict_ddp_mode=False, + device=torch.device('cpu'), + placement_policy='cpu', + pin_memory=True, + hidden_dim=8192, + search_range_mb=128) + gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config) + optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) + gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config) + + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + # test gemini + time_list = [] + set_seed(42) + data_args = data_gen(device="cuda") + for step in range(10): + gemini_optim.zero_grad() + torch.cuda.synchronize() + start_time = time.time() + gemini_out = gemini_model(**data_args) + gemini_loss = criterion(gemini_out, label) + gemini_optim.backward(gemini_loss) + torch.cuda.synchronize() + time_list.append(time.time() - start_time) + gemini_optim.step() + + torch.cuda.synchronize() + + exec_time = sum(sorted(time_list)[:5]) / 5 + runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 + runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 + print(f'gemini | model_name: {model_name}') + print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' + f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') + print(time_list) + + del data_args + del gemini_model + del gemini_optim + del gemini_out + del gemini_loss + + # test asyn offload + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + time_list = [] + set_seed(42) + data_args = data_gen(device="cuda") + data_args = tree_map(wrap_fn, data_args) + for step in range(10): + optim.zero_grad() + torch.cuda.synchronize() + start_time = time.time() + loss = criterion(model(**data_args), label) + optim.backward(loss) + torch.cuda.synchronize() + time_list.append(time.time() - start_time) + optim.step() + + torch.cuda.synchronize() + + exec_time = sum(sorted(time_list)[:5]) / 5 + runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 + runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 + print(f'solver_name: {solver_name} | model_name: {model_name}') + print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' + f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') + print(time_list) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_fwd_bwd() + + +@pytest.mark.skip("this test failed") +@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@rerun_if_address_is_in_use() +def test_perf(): + spawn(run_dist, 1) + + +if __name__ == '__main__': + test_perf() diff --git a/tests/test_auto_parallel/test_offload/test_solver.py b/tests/test_auto_parallel/test_offload/test_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2c9a36849fa9027caa8c66f7d8cf9d95590d0c --- /dev/null +++ b/tests/test_auto_parallel/test_offload/test_solver.py @@ -0,0 +1,67 @@ +import pytest +import torch.fx +from torch.fx import GraphModule +from torch.utils._pytree import tree_map + +from colossalai.auto_parallel.offload.region_manager import RegionManager +from colossalai.auto_parallel.offload.solver import NOT_NVML, SolverFactory +from colossalai.fx import ColoTracer, is_compatible_with_meta +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.testing import clear_cache_before_run, parameterize +from tests.test_auto_parallel.test_offload.model_utils import * + + +@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@clear_cache_before_run() +@parameterize('model_name', ['gpt2_', 'bert_']) +@parameterize('memory_budget', [4000]) +@parameterize('solver_name', ['syn', 'asyn']) +def solver_test(model_name: str, memory_budget: float, solver_name: str): + + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, data_gen = get_components_func() + data_args = data_gen(device="cpu") + wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x + data_args = tree_map(wrap_fn, data_args) + model = model_builder() + model.train() + model = model.cpu().half() + + tracer = ColoTracer() + assert is_compatible_with_meta() + wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x + meta_args = tree_map(wrap_fn, data_args) + graph = tracer.trace(model, meta_args=meta_args) + gm = GraphModule(model, graph, model.__class__.__name__) + + interp = MetaInfoProp(gm) + interp.propagate(*meta_args.values()) + + region_manager = RegionManager(graph, solver_name=solver_name) + region_manager._pre_process() + region_list = region_manager.region_list + + solver_cls = SolverFactory.create(solver_name) + memory_budget = memory_budget * 1024 * 1024 + solver = solver_cls(region_list, memory_budget) + solver._call_solver() + + assert solver.best_ts.peak_mem < memory_budget + + print("****************** execution plan *******************") + for region in region_list: + need_offload = region.need_offload + to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None + print( + f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}' + ) + for region in region_list.__reversed__(): + need_offload = region.need_offload + to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None + print( + f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}' + ) + + +if __name__ == '__main__': + solver_test() diff --git a/tests/test_auto_parallel/test_pass/__init__.py b/tests/test_auto_parallel/test_pass/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_auto_parallel/test_pass/test_node_converting_pass.py b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..429e89aae5d38c5ac76fddb66a67116c0398f6a3 --- /dev/null +++ b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py @@ -0,0 +1,56 @@ +import torch +import torch.nn.functional as F + +from colossalai.auto_parallel.passes.runtime_preparation_pass import node_args_converting_pass +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.tracer import ColoTracer +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.testing import clear_cache_before_run + + +class TestModule(torch.nn.Module): + + def forward(self, x): + x = x.view(4, 4, 2) + return x + + +def insert_narrow(gm, x_node): + graph = gm.graph + with graph.inserting_after(x_node): + shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={}) + view_node = list(x_node.users.keys())[0] + new_args = list(view_node.args) + new_args[0] = shard_node + view_node.args = tuple(new_args) + return gm + + +@clear_cache_before_run() +def test_node_args_converting_pass(): + model = TestModule() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + meta_args = {'x': torch.rand(4, 8).to('meta')} + input = torch.rand(4, 8) + tracer = ColoTracer() + graph = tracer.trace(root=model, meta_args=meta_args) + + x_node = list(graph.nodes)[0] + view_node = list(graph.nodes)[1] + sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]}) + setattr(x_node, 'sharding_spec', sharding_spec) + setattr(view_node, 'sharding_spec', sharding_spec) + + gm = ColoGraphModule(model, graph) + gm = node_args_converting_pass(gm, device_mesh) + gm = insert_narrow(gm, x_node) + gm.recompile() + output = gm(input) + assert output.shape == torch.Size([2, 4, 2]) + + +if __name__ == '__main__': + test_node_args_converting_pass() diff --git a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..bca81201c6ef03f3263369f2af7a13de67d2262d --- /dev/null +++ b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py @@ -0,0 +1,70 @@ +import pytest +import torch +import torch.nn.functional as F + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.testing import clear_cache_before_run + + +class TestModule(torch.nn.Module): + + def forward(self, x): + size = x.size() + return size + + +def insert_narrow(gm, x_node): + graph = gm.graph + with graph.inserting_after(x_node): + shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={}) + size_node = list(x_node.users.keys())[0] + size_node.args = (shard_node,) + return gm + + +def recover_narrow(gm, narrow_node): + graph = gm.graph + size_node = list(graph.nodes)[2] + x_node = narrow_node.args[0] + size_node.args = (x_node,) + graph.erase_node(narrow_node) + return gm + + +@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@clear_cache_before_run() +def test_size_value_converting_pass(): + model = TestModule() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + meta_args = {'x': torch.rand(4, 8).to('meta')} + input = torch.rand(4, 8) + tracer = ColoTracer(bias_addition_split=True) + graph = tracer.trace(root=model, meta_args=meta_args) + x_node = list(graph.nodes)[0] + x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]}) + setattr(x_node, 'sharding_spec', x_sharding_spec) + gm = ColoGraphModule(model, graph) + gm = insert_narrow(gm, x_node) + shape_prop_pass(gm, *meta_args.values()) + gm.recompile() + size = gm(input) + assert size == torch.Size([2, 8]) + + narrow_node = list(gm.graph.nodes)[1] + gm = recover_narrow(gm, narrow_node) + gm = size_value_converting_pass(gm, device_mesh) + gm = insert_narrow(gm, x_node) + gm.recompile() + size = gm(input) + assert size == torch.Size([4, 8]) + + +if __name__ == '__main__': + test_size_value_converting_pass() diff --git a/tests/test_auto_parallel/test_tensor_shard/__init__.py b/tests/test_auto_parallel/test_tensor_shard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..9fbe674ef4f4a9609e6ff1a8a5b507b64d8be7f1 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py @@ -0,0 +1,91 @@ +from functools import partial + +import pytest +import torch + +try: + from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False +except: + NO_CODEGEN = True + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn + + +class LinearModel(torch.nn.Module): + + def __init__(self, in_features, out_features): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x): + x = self.linear(x) + x = x * 2 + + return x + + +class ConvModel(torch.nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, bias=True): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias) + + def forward(self, x): + x = self.conv(x) + x = x * 2 + + return x + + +def check_linear_module(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearModel(4, 8).cuda() + input = torch.rand(4, 4).cuda() + output_compare = model(input) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + meta_args = {'x': torch.rand(4, 4).to('meta')} + gm = initialize_model(model, meta_args=meta_args, device_mesh=device_mesh) + output = gm(input) + assert_close(output, output_compare) + + +def check_conv_module(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = ConvModel(3, 6, 2).cuda() + input = torch.rand(4, 3, 64, 64).cuda() + output_compare = model(input) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + meta_args = {'x': torch.rand(4, 3, 64, 64).to('meta')} + gm = initialize_model(model, meta_args=meta_args, device_mesh=device_mesh) + output = gm(input) + assert_close(output, output_compare) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bias_addition_module(): + spawn(check_linear_module, 4) + spawn(check_conv_module, 4) + + +if __name__ == '__main__': + test_bias_addition_module() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py new file mode 100644 index 0000000000000000000000000000000000000000..5607587496f354a9b5a9623a3ea54c539dfd7ade --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py @@ -0,0 +1,64 @@ +import torch + +from colossalai.auto_parallel.tensor_shard.utils import ( + get_broadcast_shape, + is_broadcastable, + recover_sharding_spec_for_broadcast_shape, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.sharding_spec import ShardingSpec + + +def test_is_broadcastable(): + x1 = torch.rand(4, 4, 8) + x2 = torch.rand(1, 8) + assert is_broadcastable(x1.shape, x2.shape) + + x1 = torch.rand(4, 2, 8) + x2 = torch.rand(2, 8) + assert is_broadcastable(x1.shape, x2.shape) + + x1 = torch.rand(4, 2, 8) + x2 = torch.rand(4, 8) + assert not is_broadcastable(x1.shape, x2.shape) + + +def test_get_broadcast_shape(): + x1 = torch.rand(4, 4, 8) + x2 = torch.rand(1, 8) + assert get_broadcast_shape(x1.shape, x2.shape) == [4, 4, 8] + + x1 = torch.rand(4, 2, 8) + x2 = torch.rand(2, 8) + assert get_broadcast_shape(x1.shape, x2.shape) == [4, 2, 8] + + x1 = torch.rand(4, 2, 8) + x2 = torch.rand(8) + assert get_broadcast_shape(x1.shape, x2.shape) == [4, 2, 8] + + +def test_recover_sharding_spec_for_broadcast_shape(): + x1 = torch.rand(4, 1, 8) + x2 = torch.rand(2, 8) + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + + broadcast_shape = get_broadcast_shape(x1.shape, x2.shape) + logical_sharding_spec_for_x1 = ShardingSpec(device_mesh=device_mesh, + dim_partition_dict={ + 0: [0], + 1: [1] + }, + entire_shape=broadcast_shape) + physical_sharding_spec_for_x1, removed_dims = recover_sharding_spec_for_broadcast_shape( + logical_sharding_spec_for_x1, broadcast_shape, x1.shape) + print(physical_sharding_spec_for_x1) + + assert physical_sharding_spec_for_x1.entire_shape == x1.shape + # dim 1 for the physical tensor is of broadcast type MULTIPLE, so should ignore + assert physical_sharding_spec_for_x1.dim_partition_dict == {0: [0]} + assert physical_sharding_spec_for_x1.sharding_sequence == ['S0', 'R', 'R'] diff --git a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..398458306e3d34c9dceba3dfc3030d98a2a3e5ad --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py @@ -0,0 +1,68 @@ +from typing import Optional, Tuple + +import pytest +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from transformers.pytorch_utils import Conv1D + +try: + from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False +except: + NO_CODEGEN = True + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn + +HIDDEN_SIZE = 16 + + +class GPT2MLPWithCkpt(nn.Module): + + def __init__(self, intermediate_size, hidden_size): + super().__init__() + embed_dim = hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = torch.nn.ReLU() + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = checkpoint(self.c_proj, hidden_states) + hidden_states = self.act(hidden_states) + + return hidden_states + + +def check_act_ckpt(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE) + input = torch.rand(1, 64, HIDDEN_SIZE) + input_sample = { + 'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'), + } + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + gm = initialize_model(model, input_sample, device_mesh) + code = gm.module.graph.python_code('self').src + assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code + assert "view_3 = torch.utils.checkpoint.checkpoint(self.checkpoint_0, view_1, comm_actions_dict, use_reentrant=False)" in code + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_mlp_layer(): + spawn(check_act_ckpt, 4) + + +if __name__ == '__main__': + test_mlp_layer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..6908a17818691ecae401f5ada98f94d6c9c84621 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py @@ -0,0 +1,105 @@ +import copy + +import pytest +import torch +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False +except: + NO_CODEGEN = True + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn + + +class MLP(torch.nn.Module): + + def __init__(self, in_features): + super().__init__() + self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False) + self.linear_2 = torch.nn.Linear(4 * in_features, in_features, bias=False) + + def forward(self, x): + x = self.linear_1(x) + x = self.linear_2(x) + + return x + + +def check_compatibility_with_ddp(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = MLP(4).cuda() + if rank in [0, 1]: + input = torch.arange(0, 16, dtype=torch.float).reshape(4, 4).cuda() + elif rank in [2, 3]: + input = torch.arange(16, 32, dtype=torch.float).reshape(4, 4).cuda() + input_compare = torch.arange(0, 32, dtype=torch.float).reshape(8, 4).cuda() + output_compare = model(input_compare) + loss_compare = output_compare.sum() + loss_compare.backward() + grad_compare = copy.deepcopy(model.linear_1.weight.grad / 2) + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + meta_args = {'x': torch.rand(4, 4).to('meta')} + gm, solution = initialize_model(model, + meta_args=meta_args, + device_mesh=device_mesh, + return_solution=True, + solver_preference='tp', + shard_option='shard_last_axis') + + msg = '| TP strategy combination chosen by auto-parallel solver |' + msg_length = len(msg) + if rank == 0: + print('=' * msg_length) + print(msg) + print('=' * msg_length) + for strategy in solution: + print(strategy) + print('=' * msg_length) + + dp_process_group = None + for (ranks, process_group_handle) in device_mesh.process_groups_dict[0]: + if rank in ranks: + dp_process_group = process_group_handle + assert dp_process_group is not None + gm = DDP(gm, process_group=dp_process_group) + output = gm(input) + + if rank in (0, 1): + assert_close(output, output_compare.narrow(0, 0, 4)) + else: + assert_close(output, output_compare.narrow(0, 4, 4)) + print(f'output on rank{rank} is correct') + loss = output.sum() + + loss.backward() + + if rank in (0, 2): + assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 0, 8)) + + if rank in (1, 3): + assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 8, 8)) + + print(f'gradient on rank{rank} is correct') + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_compatibility_with_ddp(): + spawn(check_compatibility_with_ddp, 4) + + +if __name__ == '__main__': + test_compatibility_with_ddp() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py new file mode 100644 index 0000000000000000000000000000000000000000..05704acbf7fdb9e801ab47d138026f9885e8c92f --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -0,0 +1,113 @@ +import copy + +import pytest +import torch + +try: + from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False +except: + NO_CODEGEN = True + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor.process_group import ProcessGroup +from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn +from colossalai.utils import get_current_device +from colossalai.zero import post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper + + +class MLP(torch.nn.Module): + + def __init__(self, in_features): + super().__init__() + self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False) + self.linear_2 = torch.nn.Linear(4 * in_features, in_features, bias=False) + + def forward(self, x): + x = self.linear_1(x) + x = self.linear_2(x) + + return x + + +def check_auto_parallel_with_gemini(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = MLP(4).half().cuda() + if rank in [0, 1]: + input = torch.arange(0, 16).reshape(4, 4).half().cuda() + elif rank in [2, 3]: + input = torch.arange(16, 32).reshape(4, 4).half().cuda() + input_compare = torch.arange(0, 32).reshape(8, 4).half().cuda() + output_compare = model(input_compare) + loss_compare = output_compare.sum() + loss_compare.backward() + grad_compare = copy.deepcopy(model.linear_1.weight.grad / 2) + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + meta_args = {'x': torch.rand(4, 4).half().to('meta')} + gm, solution = initialize_model(model, + meta_args=meta_args, + device_mesh=device_mesh, + return_solution=True, + solver_preference='tp', + shard_option='shard_last_axis') + + if rank == 0: + msg = '| TP strategy combination chosen by auto-parallel solver |' + msg_length = len(msg) + print('=' * msg_length) + print(msg) + print('=' * msg_length) + for strategy in solution: + print(strategy) + print('=' * msg_length) + + dp_process_group = ProcessGroup(rank=rank, ranks=[0, 1, 2, 3], tp_degree=2, dp_degree=2) + gemini_config = dict(strict_ddp_mode=False, + device=get_current_device(), + placement_policy='cpu', + pin_memory=True, + search_range_mb=128) + + post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group) + gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config) + optimizer = HybridAdam(gm.parameters(), betas=(0, 0)) + optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1) + output = gm(input) + if rank in (0, 1): + assert_close(output, output_compare.narrow(0, 0, 4)) + else: + assert_close(output, output_compare.narrow(0, 4, 4)) + print(f'output on rank{rank} is correct') + loss = output.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + + if rank in (0, 2): + assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 0, 8).flatten()) + + if rank in (1, 3): + assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 8, 8).flatten()) + + print(f'gradient on rank{rank} is correct') + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_auto_parallel_with_gemini(): + spawn(check_auto_parallel_with_gemini, 4) + + +if __name__ == '__main__': + test_auto_parallel_with_gemini() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b407b240e1fc4432002ed50a1e0132348e3ded --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py @@ -0,0 +1,114 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from torch.fx import GraphModule +from transformers.pytorch_utils import Conv1D + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +# from colossalai.fx.tracer.tracer import ColoTracer +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks +from colossalai.testing import clear_cache_before_run, parameterize, run_on_environment_flag + +NUM_REPEAT_BLOCKS = 4 +BATCH_SIZE = 1 +SEQ_LENGTH = 32 +HIDDEN_DIM = 384 + + +class RepeatBlock(nn.Module): + + def __init__(self, intermediate_size, hidden_size): + super().__init__() + self.c_fc = Conv1D(intermediate_size, hidden_size) + self.c_proj = Conv1D(hidden_size, intermediate_size) + self.act = torch.nn.ReLU() + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + + return hidden_states + + +class RepeatModel(nn.Module): + + def __init__(self, intermediate_size, hidden_size, num_layers): + super().__init__() + self.blocks = nn.ModuleList([RepeatBlock(intermediate_size, hidden_size) for i in range(num_layers)]) + + def forward(self, x): + + for block in self.blocks: + x = block(x) + + return x + + +class NonRepeatBlock(nn.Module): + + def __init__(self, intermediate_size, hidden_size, layer_index): + super().__init__() + intermediate_size //= (layer_index + 1) + self.c_fc = Conv1D(intermediate_size, hidden_size) + self.c_proj = Conv1D(hidden_size, intermediate_size) + self.act = torch.nn.ReLU() + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + + return hidden_states + + +class NonRepeatModel(nn.Module): + + def __init__(self, intermediate_size, hidden_size, num_layers): + super().__init__() + self.blocks = nn.ModuleList([NonRepeatBlock(intermediate_size, hidden_size, i) for i in range(num_layers)]) + + def forward(self, x): + + for block in self.blocks: + x = block(x) + + return x + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() +@parameterize('model_cls', [RepeatModel, NonRepeatModel]) +def test_repeat_blocks(model_cls): + + model = model_cls(4 * HIDDEN_DIM, HIDDEN_DIM, NUM_REPEAT_BLOCKS) + + tracer = ColoTracer(bias_addition_split=True) + input_sample = {'x': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta')} + graph = tracer.trace(root=model, meta_args=input_sample) + + gm = GraphModule(model, graph, model.__class__.__name__) + shape_prop_pass(gm, *input_sample.values()) + gm.recompile() + + node_list = list(graph.nodes) + root_module = graph.owning_module + common_blocks = find_repeat_blocks(node_list, root_module, common_length_threshold=10) + + total_num_nodes = len(list(graph.nodes)) + # remove the input placeholder node and the output node + num_repeat_nodes_per_block = (total_num_nodes - 2) // NUM_REPEAT_BLOCKS + for common_block in common_blocks: + print(common_block) + if model_cls == RepeatModel: + assert len(common_blocks) == NUM_REPEAT_BLOCKS + assert len(common_blocks[0]) == num_repeat_nodes_per_block + elif model_cls == NonRepeatModel: + assert len(common_blocks) == 0 + + +if __name__ == '__main__': + test_repeat_blocks() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/__init__.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..22a2371311f9125cbc6854f2dbb1141d906c1ebb --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py @@ -0,0 +1,279 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers.activations import ACT2FN +from transformers.models.gpt2.modeling_gpt2 import BaseModelOutputWithPastAndCrossAttentions, GPT2PreTrainedModel +from transformers.pytorch_utils import Conv1D + + +class GPT2MLP(nn.Module): + + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + # We temporarily banned the Dropout layer because the rng state need + # to process to get the correct result. + # self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + # TODO: the rng state need to be fixed for distributed runtime + # hidden_states = self.dropout(hidden_states) + return hidden_states + + +# The reason Why we don't import GPT2Attention from transformers directly is that: +# 1. The tracer will not work correctly when we feed meta_args and concrete_args at same time, +# so we have to build the customized GPT2Attention class and remove the conditional branch manually. +# 2. The order of split and view op has been changed in the customized GPT2Attention class, the new +# order is same as megatron-lm gpt model. +class GPT2Attention(nn.Module): + + def __init__(self, config, layer_idx=None): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), + dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + self.scale_attn_weights = config.scale_attn_weights + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.pruned_heads = set() + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / (value.size(-1)**0.5) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool) + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + # attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + + # query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + qkv = self.c_attn(hidden_states) + + # query = self._split_heads(query, self.num_heads, self.head_dim) + # key = self._split_heads(key, self.num_heads, self.head_dim) + # value = self._split_heads(value, self.num_heads, self.head_dim) + query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3) + present = (key, value) + + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + # attn_output = self.resid_dropout(attn_output) + return attn_output + + +class GPT2Block(nn.Module): + + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + # %transformer_h_0_ln_1 + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + ) + # residual connection + hidden_states = attn_outputs + residual + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + return hidden_states + + +class GPT2Model(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = ["attn.masked_bias"] + + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + + device = input_ids.device + + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + + # add_2 + hidden_states = inputs_embeds + position_embeds + + # comment to run pipeline + # add_3 + output_shape = input_shape + (hidden_states.size(-1),) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i]) + hidden_states = outputs + + hidden_states = self.ln_f(hidden_states) + # comment to run pipeline + hidden_states = hidden_states.view(output_shape) + + return hidden_states + + +class GPT2LMHeadModel(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ): + transformer_outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + ) + + lm_logits = self.lm_head(transformer_outputs) + + return lm_logits + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..48d2672c65714797c53e5e94442e6161bf5550e9 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py @@ -0,0 +1,194 @@ +import copy +import random +from typing import Dict + +import numpy as np +import pytest +import torch +import transformers +from torch.fx import GraphModule + +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +# from colossalai.fx.tracer.tracer import ColoTracer +from colossalai._analyzer.fx.tracer.tracer import ColoTracer + +try: + from colossalai.auto_parallel.tensor_shard.initialize import ( + ModuleWrapper, + build_strategy_constructor, + solve_solution, + transform_to_sharded_model, + ) + NO_CODEGEN = False +except: + NO_CODEGEN = True + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.shape_consistency import to_global +from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model + +BATCH_SIZE = 1 +SEQ_LENGTH = 32 +HIDDEN_DIM = 768 + +seed = 128 +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) +np.random.seed(seed) +random.seed(seed) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + + +def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, torch.Tensor], + best_sharding_spec_dict: Dict[str, ShardingSpec]): + for name, param in module.named_parameters(): + param_grad = param.grad + name = name.replace('module.', '') + origin_param_grad = origin_param_dict[name].grad + atoms = name.split('.') + new_name = '_'.join(atoms) + if new_name in best_sharding_spec_dict: + param_sharding_spec = best_sharding_spec_dict[new_name] + grad_to_compare = copy.deepcopy(param_grad) + param_grad_global = to_global(grad_to_compare, param_sharding_spec) + try: + assert_close_loose(param_grad_global, origin_param_grad, rtol=1e-03, atol=1e-05) + except: + difference = param_grad_global - origin_param_grad + avg_diff = difference.abs().sum() / difference.numel() + assert avg_diff < 0.001 + print(f'{name} param has {avg_diff} average difference') + + +def check_attention_layer(rank, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) + + if model_cls == GPT2MLP: + model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda') + else: + model = model_cls(config=config).to('cuda') + test_model = copy.deepcopy(model) + + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + hidden_states = torch.rand((BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), dtype=torch.float32) + + if model_cls == GPT2MLP: + input_sample = (hidden_states.to('cuda'),) + test_input_sample = copy.deepcopy(input_sample) + meta_input_sample = { + 'hidden_states': hidden_states.to('meta'), + } + elif model_cls in (GPT2Attention, GPT2Block): + input_sample = ( + hidden_states.to('cuda'), + attention_mask.to('cuda'), + ) + test_input_sample = copy.deepcopy(input_sample) + meta_input_sample = { + 'hidden_states': hidden_states.to('meta'), + 'attention_mask': attention_mask.to('meta'), + } + else: + input_sample = ( + input_ids.to('cuda'), + attention_mask.to('cuda'), + ) + test_input_sample = copy.deepcopy(input_sample) + meta_input_sample = { + 'input_ids': input_ids.to('meta'), + 'attention_mask': attention_mask.to('meta'), + } + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + tracer = ColoTracer(bias_addition_split=True) + + graph = tracer.trace(root=model, meta_args=meta_input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + shape_prop_pass(gm, *meta_input_sample.values()) + gm.recompile() + + strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard') + solution = solve_solution(gm, strategies_constructor, memory_budget=-1) + gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh, + strategies_constructor) + gm = ModuleWrapper(gm, *sharding_spec_dicts) + + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + best_sharding_spec_dict = {} + for index, node in enumerate(nodes): + best_sharding_spec_dict[node.name] = node.sharding_spec + + cuda_rng_state = torch.cuda.get_rng_state() + cpu_rng_state = torch.get_rng_state() + origin_output = test_model(*test_input_sample) + torch.cuda.set_rng_state(cuda_rng_state) + torch.set_rng_state(cpu_rng_state) + output = gm(*input_sample) + assert_close(output, origin_output, rtol=1e-03, atol=1e-03) + + #*******************backward starting******************* + cuda_rng_state = torch.cuda.get_rng_state() + cpu_rng_state = torch.get_rng_state() + output.sum().backward() + torch.set_rng_state(cpu_rng_state) + torch.cuda.set_rng_state(cuda_rng_state) + origin_output.sum().backward() + origin_param_dict = dict(test_model.named_parameters()) + + if rank == 0: + print("*******************backward starting*******************") + + _check_module_grad(gm, origin_param_dict, best_sharding_spec_dict) + + if rank == 0: + print("*******************backward finished*******************") + + #*******************backward finished******************* + + #*******************strategy selected******************* + if rank == 0: + print("*******************strategy selected*******************") + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + computation_cost = 0 + communication_cost = 0 + memory_cost = 0 + for index, node in enumerate(nodes): + print(node.name, node.strategies_vector[solution[index]].name) + computation_cost += node.strategies_vector[solution[index]].compute_cost.total + communication_cost += node.strategies_vector[solution[index]].communication_cost.total + node_memory_cost = node.strategies_vector[solution[index]].memory_cost.total + if isinstance(node_memory_cost, tuple): + node_memory_cost = node_memory_cost[0] + memory_cost += node_memory_cost.activation + node_memory_cost.parameter + + print(f'computation cost is {computation_cost}') + print(f'communication cost is {communication_cost}') + print(f'memory cost is {memory_cost}') + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason="no codegen module") +@pytest.mark.dist +@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model]) +@rerun_if_address_is_in_use() +def test_mlp_layer(model_cls): + spawn(check_attention_layer, 4, model_cls=model_cls) + + +if __name__ == '__main__': + test_mlp_layer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8c3c4bf5a08c0d2c07acce6ff70f17b76212e9 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py @@ -0,0 +1,88 @@ +import torch +import transformers +from torch.fx import GraphModule + +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP +from colossalai.auto_parallel.tensor_shard.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.testing import clear_cache_before_run, parameterize +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model + +BATCH_SIZE = 1 +SEQ_LENGTH = 32 +HIDDEN_DIM = 384 + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() +@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) +def test_self_attention_block(model_cls): + config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) + if model_cls == GPT2MLP: + model = model_cls(intermediate_size=4 * config.hidden_size, config=config) + else: + model = model_cls(config=config) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer(bias_addition_split=True) + if model_cls == GPT2MLP: + input_sample = { + 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), + } + elif model_cls in (GPT2Attention, GPT2Block): + input_sample = { + 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), + 'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'), + } + else: + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + input_sample = {k: v.to('meta') for k, v in kwargs.items()} + + graph = tracer.trace(root=model, meta_args=input_sample) + + gm = GraphModule(model, graph, model.__class__.__name__) + shape_prop_pass(gm, *input_sample.values()) + print(gm.graph) + gm.recompile() + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + solver = Solver(gm.graph, strategies_constructor, cost_graph, memory_budget=-1) + ret = solver.call_solver_serialized_args() + strategies_list = solver.last_s_val + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + + computation_cost = 0 + communication_cost = 0 + memory_cost = 0 + for index, node in enumerate(nodes): + print(node.name, node.strategies_vector[strategies_list[index]].name) + computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total + communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total + node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total + if isinstance(node_memory_cost, tuple): + node_memory_cost = node_memory_cost[0] + memory_cost += node_memory_cost.activation + node_memory_cost.parameter + + print(f'computation cost is {computation_cost}') + print(f'communication cost is {communication_cost}') + print(f'memory cost is {memory_cost}') + + +if __name__ == '__main__': + test_self_attention_block() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..d10b222c060d7d96e98c05d4b3107db8405c4785 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py @@ -0,0 +1,60 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing import clear_cache_before_run + + +class LinearModel(nn.Module): + + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(4, 4) + self.relu = nn.ReLU(inplace=True) + self.linear2 = nn.Linear(4, 4) + + def forward(self, x1, x2): + x1 = x1 * 2 + x1 = self.linear1(x1) + x1 = self.relu(x1) + x1 = self.linear2(x1) + out = x1 + x2 + return out + + +@pytest.mark.skip('meta tensor has some bugs in 1.11') +@clear_cache_before_run() +def test_liveness_analysis(): + model = LinearModel() + tracer = ColoTracer(bias_addition_split=True) + meta_args = {'x1': torch.rand(4, 4, device='meta'), 'x2': torch.rand(4, 4, device='meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__) + shape_prop_pass(gm, *meta_args.values()) + + graph_analyser = GraphAnalyser(gm) + liveness_list = graph_analyser.liveness_analysis() + stage_count = len(liveness_list) + + # if a LiveStage is covered by another LiveStage, we just keep the larger one. + assert stage_count == 1 + + # a variable named `relu` must exist + # and this live var must have inplace = True + assert liveness_list[0].all_live_vars.exists('relu') + relu_var = liveness_list[0].all_live_vars.get('relu') + assert relu_var.is_inplace + + # the unique vars must be fewer than the all vars since in-place ops exist + all_live_vars = liveness_list[0].all_live_vars + unique_live_vars = liveness_list[0].unique_live_vars + assert len(unique_live_vars) + 1 == len(all_live_vars) + + +if __name__ == '__main__': + test_liveness_analysis() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a2133e654eafd3bb3719155a101e0a937f38b8 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py @@ -0,0 +1,62 @@ +import pytest +import torch + +from colossalai.auto_parallel.meta_profiler import meta_register +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType +from colossalai.testing.utils import clear_cache_before_run, parameterize +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results + + +@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() +@parameterize('func', [ + torch.nn.functional.softmax, + torch.nn.functional.relu, + torch.tanh, + torch.nn.functional.dropout, +]) +def test_activation_meta_info(func): + meta_func = meta_register.get(func) + # construct meta tensors + input_tensor = torch.rand(256, 1024, device="meta") + output_tensor = torch.rand(256, 1024, device="meta") + softmax_dim = 0 + + # construct operation data + input_data = OperationData(name='input', type=OperationDataType.ARG, data=input_tensor) + output_data = OperationData(name='output', type=OperationDataType.OUTPUT, data=output_tensor) + softmax_dim_data = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim) + + # construct args and kwargs + args = [input_data, softmax_dim_data, output_data] + kwargs = {'inplace': False} + + # estimated results + compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) + + # actual results + input_real_tensor = torch.rand(256, 1024, device="cuda") + + input_real_tensor.requires_grad = True + + # fwd + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + output_real_tensor = func(input_real_tensor) + fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + # bwd + upstream_grad = torch.rand_like(output_real_tensor) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + torch.autograd.backward(output_real_tensor, upstream_grad) + bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak, + bwd_allocated, bwd_peak) + + +if __name__ == '__main__': + test_activation_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..68ccc7835bc354ca9979964b144e6a42f0edf560 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py @@ -0,0 +1,65 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import rerun_if_address_is_in_use, spawn +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + + +class BinaryElementwiseOpModule(nn.Module): + + def __init__(self, token=torch.add, shape=64) -> None: + super().__init__() + self.token = token + self.param = nn.Parameter(torch.rand(shape)) + + def forward(self, input): + return input + self.param + + +def _binary_elementwise_mem_test(rank, world_size, port): + """This function is for binary elementwise ops memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = BinaryElementwiseOpModule(token=torch.add, shape=1024).cuda() + input = torch.rand(32, 1024).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 2 + # total number of target node strategies + strategy_number = 9 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_binary_elementwise_meta_concrete_info_match(): + spawn(_binary_elementwise_mem_test, 4) + + +if __name__ == '__main__': + test_binary_elementwise_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f7b88f44a50042cd96ce0548ca47906bf167fd --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py @@ -0,0 +1,104 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import rerun_if_address_is_in_use, spawn +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + + +class ConvFunctionModule(nn.Module): + + def __init__(self, in_channels=4, out_channels=64, kernel_size=3): + super().__init__() + self.conv_weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + + def forward(self, input): + return nn.functional.conv2d(input, self.conv_weight) + + +def _conv_module_mem_test(rank, world_size, port, bias): + """This function is for conv memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.Conv2d(4, 64, 3, padding=1, bias=bias)).cuda() + input = torch.rand(4, 4, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 1 + # total number of target node strategies + strategy_number = 16 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_conv_meta_concrete_info_match(bias=False): + spawn(_conv_module_mem_test, 4, bias=bias) + + +def _conv_function_mem_test(rank, world_size, port): + """This function is for conv function memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = ConvFunctionModule().cuda() + input = torch.rand(4, 4, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 2 + # total number of target node strategies + strategy_number = 16 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_conv_function_concrete_info_match(): + spawn(_conv_function_mem_test, 4) + + +if __name__ == '__main__': + # test_conv_meta_concrete_info_match() + test_conv_function_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f76a95c4a5f4d7832b5916e4172276d2c43a41 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py @@ -0,0 +1,60 @@ +import pytest +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType +from colossalai.testing.utils import clear_cache_before_run +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results + +if torch.__version__ >= '1.12.0': + from colossalai.auto_parallel.meta_profiler import meta_register + + +@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() +def test_embedding_meta_info(): + meta_func = meta_register.get(torch.nn.Embedding) + + # construct meta tensors + input_tensor = torch.randint(0, 50256, (8, 1024), device="meta") + weight_tensor = torch.rand(50257, 1024, device="meta") + output_tensor = torch.rand(8, 1024, 1024, device="meta") + + # construct operation data + input_data = OperationData(name="input", type=OperationDataType.ARG, data=input_tensor) + + weight_data = OperationData(name="weight", type=OperationDataType.PARAM, data=weight_tensor) + + output_data = OperationData(name="output", type=OperationDataType.OUTPUT, data=output_tensor) + + # construct args and kwargs + args = [input_data, weight_data, output_data] + kwargs = {'inplace': False} + + # estimated results + compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) + + # actual results + input_real_tensor = torch.randint(0, 50256, (8, 1024), device="cuda") + embedding_module = torch.nn.Embedding(50257, 1024).cuda() + + # fwd + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + output_real_tensor = embedding_module(input_real_tensor) + fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + # bwd + upstream_grad = torch.rand_like(output_real_tensor) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + torch.autograd.backward(output_real_tensor, upstream_grad) + bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak, + bwd_allocated, bwd_peak) + + +if __name__ == '__main__': + test_embedding_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..fb3ded339ddf17d6fa2c9f0927850bf16631035d --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py @@ -0,0 +1,97 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import rerun_if_address_is_in_use, spawn +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + + +class MyModule(nn.Module): + + def __init__(self, in_features=64, out_features=128): + super().__init__() + self.fc_weight = nn.Parameter(torch.randn(out_features, in_features)) + + def forward(self, input): + return nn.functional.linear(input, self.fc_weight) + + +def _linear_module_mem_test(rank, world_size, port): + """This function is for linear memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether linear module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.Linear(64, 128, bias=False)).cuda() + input = torch.rand(8, 8, 16, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # memory test + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=1, + strategy_number=13, + input_args=[input], + meta_arg_names=["input"]) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_module_meta_concrete_info_match(): + spawn(_linear_module_mem_test, 4) + + +def _linear_function_mem_test(rank, world_size, port): + """This function is for linear memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether linear module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = MyModule().cuda() + input = torch.rand(8, 8, 16, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # memory test + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=2, + strategy_number=24, + input_args=[input], + meta_arg_names=["input"]) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_function_meta_concrete_info_match(): + spawn(_linear_function_mem_test, 4) + + +if __name__ == '__main__': + # test_linear_module_meta_concrete_info_match() + test_linear_function_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..2d2d77f0c637092664174ab8600f55bdfc4c1a46 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py @@ -0,0 +1,93 @@ +import pytest +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem +from colossalai.testing.utils import clear_cache_before_run, parameterize +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results + +if torch.__version__ >= '1.12.0': + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register + + +@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() +@parameterize( + 'tensor_shapes', + [ + [[128], [128]], # dot product + [[64, 128], [128]], # mat-vec + [[128], [128, 64]], # vec-mat + [[64, 64, 128], [128]], # batched mat-vec + [[128], [64, 128, 64]], # vec-batched mat + [[64, 128], [128, 192]], # mat-mat + [[64, 64, 128], [128, 192]], # batched mat-mat + [[64, 128], [64, 128, 192]], # mat-batched mat + [[64, 64, 128], [64, 128, 192]], # batched mat-batched mat (matched batch dims) + [[64, 1, 64, 128], [64, 128, 192]], # batched mat-batched mat (unmatched batch dims) + ]) +def test_matmul_function_meta_info(tensor_shapes): + meta_func = meta_register.get(torch.matmul) + + # construct meta tensors + input_tensor = torch.rand(*tensor_shapes[0], device="meta") + other_tensor = torch.rand(*tensor_shapes[1], device="meta") + output_tensor = torch.matmul(input_tensor, other_tensor) + + # construct operation data + input_data = OperationData( + name="input", + data=input_tensor, + type=OperationDataType.ARG, + logical_shape=input_tensor.shape, + ) + other_data = OperationData( + name="other", + data=other_tensor, + type=OperationDataType.ARG, + logical_shape=other_tensor.shape, + ) + output_data = OperationData( + name="output", + data=output_tensor, + type=OperationDataType.OUTPUT, + logical_shape=output_tensor.shape, + ) + + # construct args and kwargs + args = [input_data, other_data, output_data] + kwargs = {'inplace': False} + + # estimated results + compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) + + # actual results + input_real_tensor = torch.rand(*tensor_shapes[0], device="cuda:0") + other_real_tensor = torch.rand(*tensor_shapes[1], device="cuda:0") + + input_real_tensor.requires_grad = True + other_real_tensor.requires_grad = True + + # fwd + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + output_real_tensor = torch.matmul(input_real_tensor, other_real_tensor) + fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + # bwd + upstream_grad = torch.rand_like(output_real_tensor) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + torch.autograd.backward(output_real_tensor, upstream_grad) + bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + compute_cost: TrainCycleItem + memory_cost: TrainCycleItem + + print_results([input_real_tensor, other_real_tensor], [output_real_tensor], compute_cost, memory_cost, + fwd_allocated, fwd_peak, bwd_allocated, bwd_peak) + + +if __name__ == '__main__': + test_matmul_function_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..808172977b6046dd55a6e1ee8c8a251774cb5a1f --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py @@ -0,0 +1,117 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use, spawn +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results + +if torch.__version__ >= '1.12.0': + from colossalai.auto_parallel.meta_profiler import meta_register + + +def _batchnorm_module_mem_test(rank, world_size, port): + """This function is for batchnorm memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.BatchNorm2d(128)).cuda() + input = torch.rand(4, 128, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 1 + # total number of target node strategies + strategy_number = 9 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_batchnorm_meta_concrete_info_match(): + spawn(_batchnorm_module_mem_test, 4) + + +@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='need pytorch 1.12.0 or higher for aten level operations') +@parameterize('tensor_shape', [ + [256, 1024], + [1024, 256], +]) +def test_layernorm_meta_info(tensor_shape): + meta_func = meta_register.get(torch.nn.LayerNorm) + + # construct input + input_tensor = torch.rand(*tensor_shape, device="meta") + output_tensor = torch.rand(*tensor_shape, device="meta") + weight_tensor = torch.rand(tensor_shape[1], device="meta") + bias_tensor = torch.rand(tensor_shape[1], device="meta") + + # construct operation data + input_data = OperationData(name="input", type=OperationDataType.ARG, data=input_tensor) + + output_data = OperationData(name="output", type=OperationDataType.OUTPUT, data=output_tensor) + + weight_data = OperationData(name="weight", type=OperationDataType.PARAM, data=weight_tensor) + + bias_data = OperationData(name="bias", type=OperationDataType.PARAM, data=bias_tensor) + + # construct args and kwargs + args = [input_data, output_data, weight_data, bias_data] + kwargs = {'inplace': False} + + # estimated results + compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) + + # actual results + input_real_tensor = torch.rand(*tensor_shape, device="cuda:0") + + input_real_tensor.requires_grad = True + + ln_module = torch.nn.LayerNorm(tensor_shape[1]).cuda() + + # fwd + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + output_real_tensor = ln_module(input_real_tensor) + fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + # bwd + upstream_grad = torch.rand_like(output_real_tensor) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + torch.autograd.backward(output_real_tensor, upstream_grad) + bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + compute_cost: TrainCycleItem + memory_cost: TrainCycleItem + + print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak, + bwd_allocated, bwd_peak) + + +if __name__ == '__main__': + test_batchnorm_meta_concrete_info_match() + test_layernorm_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..4cddf4e19fcabd1174854beebd89be6bafe64b74 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py @@ -0,0 +1,93 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import rerun_if_address_is_in_use, spawn +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + + +def _adaptiveavgpool_module_mem_test(rank, world_size, port): + """This function is for AdaptiveAvgPool memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.AdaptiveAvgPool2d((16, 16))).cuda() + input = torch.rand(4, 128, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 1 + # total number of target strategies + strategy_number = 1 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_adaptiveavgpool_meta_concrete_info_match(): + spawn(_adaptiveavgpool_module_mem_test, 4) + + +def _maxpool_module_mem_test(rank, world_size, port): + """This function is for MaxPool memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.MaxPool2d((16, 16))).cuda() + input = torch.rand(4, 128, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 1 + # total number of target node strategies + strategy_number = 9 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_maxpool_meta_concrete_info_match(): + spawn(_maxpool_module_mem_test, 4) + + +if __name__ == '__main__': + test_adaptiveavgpool_meta_concrete_info_match() + test_maxpool_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..6e8145885d67dc9596c71015ae8c0e134fe29ad6 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py @@ -0,0 +1,87 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType +from colossalai.testing.utils import clear_cache_before_run +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results + +if torch.__version__ >= '1.12.0': + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register + + +class SplitModule(nn.Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + return x.split(512, dim=0) + + +@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() +def test_tensor_meta_info(): + """test tensor related meta information + We will just use torch.Tensor.split for the test + """ + meta_func = meta_register.get(torch.Tensor.split) + + # construct meta tensors + input_tensor = torch.rand(1024, 1024, device="meta") + output_tensor = input_tensor.split(512, dim=0) + + # construct operation data + input_data = OperationData( + name="input", + data=input_tensor, + type=OperationDataType.ARG, + logical_shape=input_tensor.shape, + ) + output_data = OperationData( + name="output", + data=output_tensor, + type=OperationDataType.OUTPUT, + logical_shape=input_tensor.shape, + ) + split_info_data = OperationData( + name='split_info', + type=OperationDataType.ARG, + data=0, + logical_shape=None, + ) + + # construct args + args = [input_data, output_data, split_info_data] + kwargs = {'inplace': False} + + # estimated results + compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) + + # actual results + model = SplitModule() + input_real_tensor = torch.rand(1024, 1024).cuda() + + input_real_tensor.requires_grad = True + + # fwd + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + output_real_tensor = model(input_real_tensor) + fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + # bwd + upstream_grad = [torch.rand_like(tensor) for tensor in output_real_tensor] + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + torch.autograd.backward(output_real_tensor, upstream_grad) + bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + print_results([input_real_tensor], output_real_tensor, compute_cost, memory_cost, fwd_allocated, fwd_peak, + bwd_allocated, bwd_peak) + + +if __name__ == "__main__": + test_tensor_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py new file mode 100644 index 0000000000000000000000000000000000000000..b4564312eeb4698e045eab5263d7f68147484b56 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py @@ -0,0 +1,89 @@ +import pytest +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem +from colossalai.testing.utils import clear_cache_before_run +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results + +if torch.__version__ >= '1.12.0': + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register + + +@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() +def test_where_meta_info(): + meta_func = meta_register.get(torch.where) + + # construct meta tensors + condition_tensor = torch.rand(1, 1, 1024, 1024) > 0.5 + condition_tensor = condition_tensor.to(device="meta") + x_tensor = torch.rand(8, 16, 1024, 1024, device="meta") + y_tensor = torch.tensor(0, device="meta") + output_tensor = torch.rand(8, 16, 1024, 1024) + + # construct operation data + condition_data = OperationData( + name="condition", + data=condition_tensor, + type=OperationDataType.ARG, + logical_shape=condition_tensor.shape, + ) + x_data = OperationData( + name="x", + data=x_tensor, + type=OperationDataType.ARG, + logical_shape=x_tensor.shape, + ) + y_data = OperationData( + name="y", + data=y_tensor, + type=OperationDataType.ARG, + logical_shape=y_tensor.shape, + ) + output_data = OperationData( + name="output", + data=output_tensor, + type=OperationDataType.OUTPUT, + logical_shape=output_tensor.shape, + ) + + # construct args and kwargs + args = [condition_data, x_data, y_data, output_data] + kwargs = {'inplace': False} + + # estimated results + compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) + + # actual results + condition_real_tensor = torch.rand(1, 1, 1024, 1024) > 0.5 + condition_real_tensor = condition_real_tensor.to(device="cuda") + x_real_tensor = torch.rand(8, 16, 1024, 1024, device="cuda") + y_real_tensor = torch.tensor(0.0, device="cuda") + + x_real_tensor.requires_grad = True + y_real_tensor.requires_grad = True + + # fwd + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + output_real_tensor = torch.where(condition_real_tensor, x_real_tensor, y_real_tensor) + fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + # bwd + upstream_grad = torch.rand_like(output_real_tensor) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + torch.autograd.backward(output_real_tensor, upstream_grad) + bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + compute_cost: TrainCycleItem + memory_cost: TrainCycleItem + + print_results([condition_real_tensor, x_real_tensor, y_real_tensor], [output_real_tensor], compute_cost, + memory_cost, fwd_allocated, fwd_peak, bwd_allocated, bwd_peak) + + +if __name__ == '__main__': + test_where_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4ca85d34da3097a54d1c568c09224e7087bc45cc --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py @@ -0,0 +1,187 @@ +import copy +from pprint import pprint +from typing import Dict, List + +import torch +from torch.fx import GraphModule + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +# from colossalai.fx.tracer.tracer import ColoTracer +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh + +if torch.__version__ >= '1.12.0': + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo + + +def mem_test_for_node_strategy(rank: int, + model: torch.nn.Module, + device_mesh: DeviceMesh, + node_index: int, + strategy_number: int, + input_args: List[torch.Tensor], + meta_arg_names: List[str], + input_kwargs: Dict[str, torch.Tensor] = {}): + for strategy_index in range(strategy_number): + # We need to copy the model to avoid do backward more than once in same graph + model_to_shard, args_to_shard, kwargs_to_shard = copy.deepcopy(model), copy.deepcopy(input_args), copy.deepcopy( + input_kwargs) + + tracer = ColoTracer(bias_addition_split=True) + input_sample = {} + for input_arg, meta_arg_name in zip(input_args, meta_arg_names): + input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta') + for meta_kwarg_name, input_kwarg in input_kwargs.items(): + input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') + graph = tracer.trace(root=model_to_shard, meta_args=input_sample) + gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) + shape_prop_pass(gm, *input_sample.values()) + gm.recompile() + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + target_node = list(graph.nodes)[node_index] + + # solution construction + # construct the strategy for the target node + solution_len = len(strategies_constructor.leaf_strategies) + solution = [0] * solution_len + solution[node_index] = strategy_index + + # construct the strategy for the output node + placeholder_strategy = list(graph.nodes)[-1].strategies_vector[0] + + output_key = next(key for key in target_node.strategies_vector[strategy_index].sharding_specs.keys() + if key.type == OperationDataType.OUTPUT) + placeholder_strategy.sharding_specs[output_key] = target_node.strategies_vector[strategy_index].sharding_specs[ + output_key] + + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( + gm, solution, device_mesh, strategies_constructor) + gm = runtime_apply_pass(gm) + gm.recompile() + gm: GraphModule + + num_of_strategies = len(target_node.strategies_vector) + if rank == 0: + print("=======================") + print(f"#strategy_index: {strategy_index + 1}/{num_of_strategies}") + pprint(target_node.strategies_vector[strategy_index]) + + # warmup + with torch.no_grad(): + output = gm(*args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard) + + del output + # forward memory compare + if rank == 0: + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + output = gm(*args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard) + + if rank == 0: + # print forward memory allocated and peak memory stats in kb + print( + f"forward memory allocated: {(torch.cuda.memory_allocated() - mem_stamp0) / 1024} kb, peak memory stats: {(torch.cuda.max_memory_allocated() - mem_stamp0) / 1024} kb" + ) + + # backward memory compare + grad_tensors = torch.ones_like(output) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + torch.autograd.backward(output, grad_tensors) + + if rank == 0: + # print backward memory allocated and peak memory stats in kb + print( + f"backward memory allocated: {(torch.cuda.memory_allocated() - mem_stamp0) / 1024} kb, peak memory stats: {(torch.cuda.max_memory_allocated() - mem_stamp0) / 1024} kb" + ) + + # estimated memory + if target_node.op == "call_module": + metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index], + target_node.graph.owning_module.get_submodule(target_node.target)) + else: + metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index], target_node.target) + + print("estimated memory:") + print( + f"forward activation: {metainfo.memory_cost.fwd.activation / 1024} kb, forward param: {metainfo.memory_cost.fwd.parameter / 1024} kb" + ) + print( + f"forward temp: {metainfo.memory_cost.fwd.temp / 1024} kb, forward buffer: {metainfo.memory_cost.fwd.buffer / 1024} kb" + ) + print( + f"backward activation: {metainfo.memory_cost.bwd.activation / 1024} kb, backward param: {metainfo.memory_cost.bwd.parameter / 1024} kb" + ) + print( + f"backward temp: {metainfo.memory_cost.bwd.temp / 1024} kb, backward buffer: {metainfo.memory_cost.bwd.buffer / 1024} kb" + ) + print("=======================") + + +def print_results(input: List[torch.Tensor], output: List[torch.Tensor], compute_cost: TrainCycleItem, + memory_cost: TrainCycleItem, fwd_allocated, fwd_peak, bwd_allocated, bwd_peak): + """Print the results of the meta information test. + + Args: + input (List[torch.Tensor]): input tensors + output (List[torch.Tensor]): output tensors + compute_cost (TrainCycleItem): compute cost estimated by meta_func + memory_cost (TrainCycleItem): memory cost estimated by meta_func + fwd_allocated: real forward memory allocated + fwd_peak: real forward peak memory stats + bwd_allocated: real backward memory allocated + bwd_peak: real backward peak memory stats + """ + print("=====================") + print(f"input shapes: {[tensor.shape for tensor in input]}") + print(f"output shapes: {[tensor.shape for tensor in output]}") + + # estimated results + print("Estimated Results") + + # compute cost + print("compute_cost:") + print(f" fwd: {compute_cost.fwd}") + print(f" bwd: {compute_cost.bwd}") + + # memory cost + print("memory_cost:") + # fwd + print(f" fwd activation: {memory_cost.fwd.activation / 1024} KB") + print(f" fwd buffer: {memory_cost.fwd.buffer / 1024} KB") + print(f" fwd temp: {memory_cost.fwd.temp / 1024} KB") + print(f" fwd parameter: {memory_cost.fwd.parameter / 1024} KB") + + # bwd + print(f" bwd activation: {memory_cost.bwd.activation / 1024} KB") + print(f" bwd buffer: {memory_cost.bwd.buffer / 1024} KB") + print(f" bwd temp: {memory_cost.bwd.temp / 1024} KB") + print(f" bwd parameter: {memory_cost.bwd.parameter / 1024} KB") + + # actual results + print("Actual Results") + + print("memory_cost:") + # fwd + print(f" fwd allocated: {fwd_allocated / 1024} KB") + print(f" fwd peak: {fwd_peak / 1024} KB") + + # bwd + print(f" bwd allocated: {bwd_allocated / 1024} KB") + print(f" bwd peak: {bwd_peak / 1024} KB") diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/__init__.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..80e6a6c1460c756c7212ad5dac7fcf24a78ab489 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py @@ -0,0 +1,275 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class AddBMMTensorMethodModule(nn.Module): + + def __init__(self, using_kwargs): + super().__init__() + self.using_kwargs = using_kwargs + + def forward(self, bias, x1, x2): + if self.using_kwargs: + output = bias.addbmm(x1, x2, alpha=2, beta=3) + else: + output = bias.addbmm(x1, x2) + return output + + +class AddBMMTorchFunctionModule(nn.Module): + + def __init__(self, using_kwargs): + super().__init__() + self.using_kwargs = using_kwargs + + def forward(self, bias, x1, x2): + if self.using_kwargs: + output = torch.addbmm(bias, x1, x2, alpha=2, beta=3) + else: + output = torch.addbmm(bias, x1, x2) + return output + + +def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = module(using_kwargs).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + x1 = torch.rand(4, 8, 16).cuda() + x2 = torch.rand(4, 16, 8).cuda() + bias = torch.rand(bias_shape).cuda() + # the index of addbmm node in computation graph + node_index = 3 + # strategy number of addbmm node on 2d device mesh + strategy_number = 7 + # construct input args + input_args = [bias, x1, x2] + # construct meta arg names + meta_arg_names = ['bias', 'x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + tracer = ColoTracer() + # graph(): + # %bias : torch.Tensor [#users=1] = placeholder[target=bias] + # %x1 : torch.Tensor [#users=1] = placeholder[target=x1] + # %x2 : torch.Tensor [#users=1] = placeholder[target=x2] + # %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {}) + # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {}) + # return add + graph = tracer.trace(model, + meta_args={ + 'bias': torch.rand(*bias_shape).to('meta'), + "x1": torch.rand(4, 8, 16).to('meta'), + 'x2': torch.rand(4, 16, 8).to('meta') + }) + gm = ColoGraphModule(model, graph) + + bmm_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(bmm_mod_node) + + # build handler + handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 8, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([4, 16, 8]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + + assert mapping['output'].name == "bmm" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 8, 8]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + for name in strategy_name_list: + print(name) + # one batch dim + assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list + + # two batch dim + assert 'Sb01 = Sb01 x Sb01' in strategy_name_list + + # SbSi = SbSi x Sb + assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list + assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list + + # SbSj = SbR x SbSj + assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list + assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list + + # SbR = SbSk x SbSk + assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list + assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] + assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (1, 4) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + model = module(using_kwargs).cuda() + x1 = torch.rand(4, 8, 16).cuda() + x2 = torch.rand(4, 16, 8).cuda() + bias = torch.rand(bias_shape).cuda() + # the index of addbmm node in computation graph + node_index = 3 + # strategy number of addbmm node on 2d device mesh + strategy_number = 1 + # construct input args + input_args = [bias, x1, x2] + # construct meta arg names + meta_arg_names = ['bias', 'x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + + tracer = ColoTracer() + # graph(): + # %bias : torch.Tensor [#users=1] = placeholder[target=bias] + # %x1 : torch.Tensor [#users=1] = placeholder[target=x1] + # %x2 : torch.Tensor [#users=1] = placeholder[target=x2] + # %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {}) + # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {}) + # return add + graph = tracer.trace(model, + meta_args={ + 'bias': torch.rand(*bias_shape).to('meta'), + "x1": torch.rand(4, 8, 16).to('meta'), + 'x2': torch.rand(4, 16, 8).to('meta') + }) + gm = ColoGraphModule(model, graph) + bmm_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(bmm_mod_node) + + # build handler + handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 8, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([4, 16, 8]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + + assert mapping['output'].name == "bmm" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 8, 8]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + assert len(strategy_name_list) == 1 + # one batch dim + assert 'Sb0 = Sb0 x Sb0' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] + assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +@pytest.mark.skip("skip due to bias cases not ready") +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) +@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) +@parameterize('using_kwargs', [True, False]) +@rerun_if_address_is_in_use() +def test_2d_device_mesh(module, bias_shape, using_kwargs): + spawn( + check_2d_device_mesh, + 4, + module=module, + bias_shape=bias_shape, + using_kwargs=using_kwargs, + ) + + +@pytest.mark.skip("skip due to bias cases not ready") +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) +@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) +@parameterize('using_kwargs', [True, False]) +@rerun_if_address_is_in_use() +def test_1d_device_mesh(module, bias_shape, using_kwargs): + spawn( + check_1d_device_mesh, + 4, + module=module, + bias_shape=bias_shape, + using_kwargs=using_kwargs, + ) + + +if __name__ == '__main__': + test_1d_device_mesh() + test_2d_device_mesh() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6554cd81eed38976fab9a73b68b8b8bd16b026 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py @@ -0,0 +1,191 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class AddmmModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, m1, m2): + x = torch.addmm(input, m1, m2, beta=3, alpha=2) + return x + + +class AddmmModel_with_param(nn.Module): + + def __init__(self, weight_shape, bias_shape): + super().__init__() + self.weight = torch.nn.Parameter(torch.rand(weight_shape)) + self.bias = torch.nn.Parameter(torch.rand(bias_shape)) + + def forward(self, m1): + x = torch.addmm(self.bias, m1, self.weight, beta=3, alpha=2) + return x + + +def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + if model_cls == AddmmModel: + model = AddmmModel().cuda() + else: + model = AddmmModel_with_param(weight_shape=(8, 16), bias_shape=input_shape).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + if model_cls == AddmmModel: + input = torch.rand(input_shape).cuda() + m1 = torch.rand(4, 8).cuda() + m2 = torch.rand(8, 16).cuda() + # construct input args + input_args = [input, m1, m2] + # construct meta arg names + meta_arg_names = ['input', 'm1', 'm2'] + meta_args_for_tracer = {} + for meta_arg, input_arg in zip(meta_arg_names, input_args): + meta_args_for_tracer[meta_arg] = input_arg.to('meta') + + # the index of addmm node in computation graph + node_index = 4 + # strategy number of linear node + strategy_number = 14 + else: + m1 = torch.rand(4, 8).cuda() + # construct input args + input_args = [m1] + # construct meta arg names + meta_arg_names = ['m1'] + # the index of addmm node in computation graph + meta_args_for_tracer = {} + for meta_arg, input_arg in zip(meta_arg_names, input_args): + meta_args_for_tracer[meta_arg] = input_arg.to('meta') + node_index = 4 + # strategy number of linear node + strategy_number = 14 + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type='bias_module') + + tracer = ColoTracer(bias_addition_split=True) + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %m1 : torch.Tensor [#users=1] = placeholder[target=m1] + # %m2 : torch.Tensor [#users=1] = placeholder[target=m2] + # %transpose : [#users=1] = call_function[target=torch.transpose](args = (%m2, 0, 1), kwargs = {}) + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%m1, %transpose), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%input_1, 3), kwargs = {}) + # %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {}) + # return add + graph = tracer.trace(model, meta_args=meta_args_for_tracer) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args_for_tracer.values()) + # [input_1, m1, m2, addmm, output] + node_list = list(graph.nodes) + linear_node = node_list[4] + strategies_vector = StrategiesVector(linear_node) + + # build handler + handler = LinearFunctionHandler(node=linear_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + assert mapping['input'].name == "m1" + assert mapping['input'].data.shape == torch.Size([4, 8]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 8]) + + assert mapping['other'].name == "transpose" + assert mapping['other'].data.shape == torch.Size([16, 8]) + if model_cls == AddmmModel: + assert mapping['other'].type == OperationDataType.ARG + else: + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([8, 16]) + + assert mapping['output'].name == "linear" + assert mapping['output'].data.shape == torch.Size([4, 16]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # SS = SR x RS + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR_0' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('m1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('transpose') + output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@parameterize('input_shape', [(16,), (4, 16)]) +@parameterize('model_cls', [AddmmModel, AddmmModel_with_param]) +@rerun_if_address_is_in_use() +def test_addmm_handler(input_shape, model_cls): + spawn(check_addmm_function_handler, 4, input_shape=input_shape, model_cls=model_cls) + + +if __name__ == '__main__': + test_addmm_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..b47b3508ad1b9047e492ba8b10ba87c1d96b311a --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py @@ -0,0 +1,116 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import BatchNormModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +def check_bn_module_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.BatchNorm2d(16)).cuda() + + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 16, 64, 64).cuda() + # the index of bn node in computation graph + node_index = 1 + # the total number of bn strategies without sync bn mode + # TODO: add sync bn stategies after related passes ready + strategy_number = 4 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + tracer = ColoTracer(bias_addition_split=True) + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) + # return _0 + meta_args = {"input": torch.rand(4, 16, 64, 64).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + bn_mod_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(bn_mod_node) + + # build handler + handler = BatchNormModuleHandler(node=bn_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 16, 64, 64]) + + assert mapping['other'].name == "weight" + assert mapping['other'].data.shape == torch.Size([16]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([16]) + + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.shape == torch.Size([16]) + assert mapping['bias'].type == OperationDataType.PARAM + assert mapping['bias'].logical_shape == torch.Size([16]) + + assert mapping['output'].name == "_0" + assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # RS = RS x S + assert 'RS0 = RS0 x S0' in strategy_name_list + assert 'RS1 = RS1 x S1' in strategy_name_list + + # RR = RR x R + assert 'RR = RR x R' in strategy_name_list + + # RS01 = RS01 x S01 + assert 'RS01 = RS01 x S01' in strategy_name_list + + # temporarily skip the sync bn test + # TODO: test sync bn after the implicit runtime pass completed + # SR = SR x R WITH SYNC_BN + # assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list + # assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list + + # SS = SS x S WITH SYNC_BN + # assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list + # assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list + + # S01R = S01R x R WITH SYNC_BN + # assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bn_module_handler(): + spawn(check_bn_module_handler, 4) + + +if __name__ == '__main__': + test_bn_module_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py new file mode 100644 index 0000000000000000000000000000000000000000..800bc11a50e443eeb59598ce1ef0317b96589899 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py @@ -0,0 +1,169 @@ +import pytest +import torch +import torch.nn.functional as F + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + +WEIGHT_SHAPE = (32, 16) + + +class LinearModule(torch.nn.Module): + + def __init__(self, weight_shape): + super().__init__() + self.weight = torch.nn.Parameter(torch.rand(*weight_shape)) + self.bias = torch.nn.Parameter(torch.rand(weight_shape[0])) + + def forward(self, x): + x = F.linear(x, self.weight, bias=self.bias) + return x + + +def check_linear_module_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearModule(weight_shape=WEIGHT_SHAPE).cuda() + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 4, 4, 16).cuda() + # the index of linear node in computation graph + node_index = 3 + # strategy number of linear node + strategy_number = 24 + # construct input args + input_args = [input] + # construct meta arg names + meta_arg_names = ['x'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type='bias_module') + + tracer = ColoTracer(bias_addition_split=True) + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %weight : [#users=1] = get_attr[target=weight] + # %bias : [#users=1] = get_attr[target=bias] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %weight), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %bias), kwargs = {}) + # return add + meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + linear_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = LinearFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x" + assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([64, 16]) + + assert mapping['other'].name == "weight" + assert mapping['other'].data.shape == torch.Size([32, 16]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([16, 32]) + + assert 'bias' not in mapping + + assert mapping['output'].name == "linear" + assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # SS = SR x RS + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR_0' in strategy_name_list + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('x') + weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') + output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_handler(): + spawn(check_linear_module_handler) + + +if __name__ == '__main__': + test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py new file mode 100644 index 0000000000000000000000000000000000000000..c29a065d10baf4faee039bd1965650abdbfc6062 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py @@ -0,0 +1,158 @@ +import pytest +import torch + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class LinearModule(torch.nn.Module): + + def __init__(self, in_features, out_features, bias): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x): + x = self.linear(x) + return x + + +def check_linear_module_handler(rank, world_size, port, bias): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearModule(16, 32, bias=bias).cuda() + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 4, 4, 16).cuda() + # the index of linear node in computation graph + node_index = 3 + # strategy number of linear node + strategy_number = 24 + # construct input args + input_args = [input] + # construct meta arg names + meta_arg_names = ['x'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type='bias_module') + + tracer = ColoTracer(bias_addition_split=True) + meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + linear_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = LinearFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x" + assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([64, 16]) + + assert mapping['other'].name == "linear_weight" + assert mapping['other'].data.shape == torch.Size([32, 16]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([16, 32]) + + assert 'bias' not in mapping + + assert mapping['output'].name == "linear" + assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # SS = SR x RS + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR_0' in strategy_name_list + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('x') + weight_sharding_spec = strategy.get_sharding_spec_by_name('linear_weight') + output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_handler(bias=True): + spawn(check_linear_module_handler, bias=bias) + + +if __name__ == '__main__': + test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..83f3aafe220eaadb31fbada624cf6bb65c899df2 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py @@ -0,0 +1,260 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler import BinaryElementwiseHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + class BinaryElementwiseOpModel(nn.Module): + + def __init__(self, op): + super().__init__() + self.op = op + + def forward(self, x1, x2): + out = self.op(x1, x2) + return out + + model = BinaryElementwiseOpModel(op).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + x1 = torch.rand(4, 4).cuda() + x2 = torch.rand([4] * other_dim).cuda() + # the index of binary-elementwise node in computation graph + node_index = 2 + # strategy number of binary-elementwise node + strategy_number = 9 + # construct input args + input_args = [x1, x2] + # construct meta arg names + meta_arg_names = ['x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + + tracer = ColoTracer(bias_addition_split=True) + meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + op_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(op_node) + + # build handler + handler = BinaryElementwiseHandler(node=op_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4]) + + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([4] * other_dim) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 4]) + + assert mapping['output'].name == str(op_node) + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 4]) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size([4, 4]) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # one strategy will be converted to different physical sharding spec + assert len(strategy_name_list) == 9 + + # check if the sharding strategy is correct + assert '[S0, S1] = [S0, S1] [S0, S1]' in strategy_name_list + assert '[S1, S0] = [S1, S0] [S1, S0]' in strategy_name_list + assert '[S01, R] = [S01, R] [S01, R]' in strategy_name_list + assert '[R, S01] = [R, S01] [R, S01]' in strategy_name_list + assert '[S0, R] = [S0, R] [S0, R]' in strategy_name_list + assert '[R, S0] = [R, S0] [R, S0]' in strategy_name_list + assert '[S1, R] = [S1, R] [S1, R]' in strategy_name_list + assert '[R, S1] = [R, S1] [R, S1]' in strategy_name_list + assert '[R, R] = [R, R] [R, R]' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node)) + + # make sure the sharding spec is the same for input and output + assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence + + # since the dim of the other can change, we make sure at least its last dim sharding is the same + if len(other_sharding_spec.sharding_sequence) == 2: + assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence + elif len(other_sharding_spec.sharding_sequence) == 1: + assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1] + + +class BEOpModelWithNodeConst(nn.Module): + + def __init__(self, op): + super().__init__() + self.op = op + + def forward(self, x1): + const = x1.dim() + out = self.op(x1, const) + return out + + +class BEOpModelWithIntConst(nn.Module): + + def __init__(self, op, const): + super().__init__() + self.op = op + self.const = const + + def forward(self, x1): + out = self.op(x1, self.const) + return out + + +def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + if model_cls == BEOpModelWithNodeConst: + model = model_cls(op).cuda() + else: + model = model_cls(op, other_dim).cuda() + x1 = torch.rand(4, 4).cuda() + # the index of binary-elementwise node in computation graph + node_index = 1 + # strategy number of binary-elementwise node + strategy_number = 9 + # construct input args + input_args = [x1] + # construct meta arg names + meta_arg_names = ['x1'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + tracer = ColoTracer(bias_addition_split=True) + meta_args = {'x1': torch.rand(4, 4).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + if model_cls == BEOpModelWithNodeConst: + op_node = list(graph.nodes)[2] + else: + op_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(op_node) + + # build handler + handler = BinaryElementwiseHandler(node=op_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4]) + + assert mapping['output'].name == str(op_node) + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 4]) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size([4, 4]) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # one strategy will be converted to different physical sharding spec + assert len(strategy_name_list) == 9 + + # check if the sharding strategy is correct + assert '[S0, S1] = [S0, S1] [S0, S1]' in strategy_name_list + assert '[S1, S0] = [S1, S0] [S1, S0]' in strategy_name_list + assert '[S01, R] = [S01, R] [S01, R]' in strategy_name_list + assert '[R, S01] = [R, S01] [R, S01]' in strategy_name_list + assert '[S0, R] = [S0, R] [S0, R]' in strategy_name_list + assert '[R, S0] = [R, S0] [R, S0]' in strategy_name_list + assert '[S1, R] = [S1, R] [S1, R]' in strategy_name_list + assert '[R, S1] = [R, S1] [R, S1]' in strategy_name_list + assert '[R, R] = [R, R] [R, R]' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node)) + + # make sure the sharding spec is the same for input and output + assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@parameterize('op', [torch.add]) +@parameterize('other_dim', [1, 2]) +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_binary_elementwise_handler_with_tensor(op, other_dim): + spawn( + check_binary_elementwise_handler_with_tensor, + 4, + op=op, + other_dim=other_dim, + ) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@parameterize('op', [torch.add]) +@parameterize('other_dim', [1, 2]) +@parameterize('model_cls', [BEOpModelWithNodeConst, BEOpModelWithIntConst]) +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_binary_elementwise_handler_with_int(op, model_cls, other_dim): + spawn( + check_binary_elementwise_handler_with_int, + 4, + op=op, + model_cls=model_cls, + other_dim=other_dim, + ) + + +if __name__ == '__main__': + test_binary_elementwise_handler_with_tensor() + test_binary_elementwise_handler_with_int() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..f4fdc458f80ea874bc6a3916e2974558e0663c72 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py @@ -0,0 +1,210 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class BMMTensorMethodModule(nn.Module): + + def forward(self, x1, x2): + return x1.bmm(x2) + + +class BMMTorchFunctionModule(nn.Module): + + def forward(self, x1, x2): + return torch.bmm(x1, x2) + + +def check_2d_device_mesh(rank, module, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = module().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + x1 = torch.rand(4, 8, 16).cuda() + x2 = torch.rand(4, 16, 8).cuda() + # the index of bmm node in computation graph + node_index = 2 + # strategy number of bmm node on 2d device mesh + strategy_number = 7 + # construct input args + input_args = [x1, x2] + # construct meta arg names + meta_arg_names = ['x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + tracer = ColoTracer(bias_addition_split=True) + meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + linear_mod_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = BMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 8, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([4, 16, 8]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + + assert mapping['output'].name == "bmm" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 8, 8]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # one batch dim + assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list + + # two batch dim + assert 'Sb01 = Sb01 x Sb01' in strategy_name_list + + # SbSi = SbSi x Sb + assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list + assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list + + # SbSj = SbR x SbSj + assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list + assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list + + # SbR = SbSk x SbSk + assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list + assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +def check_1d_device_mesh(rank, module, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = module().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (1, 4) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + x1 = torch.rand(4, 8, 16).cuda() + x2 = torch.rand(4, 16, 8).cuda() + # the index of bmm node in computation graph + node_index = 2 + # strategy number of bmm node on 1d device mesh + strategy_number = 1 + # construct input args + input_args = [x1, x2] + # construct meta arg names + meta_arg_names = ['x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + tracer = ColoTracer(bias_addition_split=True) + meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + linear_mod_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = BMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 8, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([4, 16, 8]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + + assert mapping['output'].name == "bmm" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 8, 8]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + assert len(strategy_name_list) == 1 + # one batch dim + assert 'Sb0 = Sb0 x Sb0' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) +@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bmm_handler(module): + spawn(check_2d_device_mesh, 4, module=module) + spawn(check_1d_device_mesh, 4, module=module) + + +if __name__ == '__main__': + test_bmm_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..f9632b1cd8f9bb02299718d0917379cbcdbde78b --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py @@ -0,0 +1,315 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +def check_conv_module_handler(rank, world_size, port, bias): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) + # return _0 + input = torch.rand(4, 4, 64, 64).cuda() + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of conv node in computation graph + node_index = 1 + # total number of conv strategies + strategy_number = 16 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + tracer = ColoTracer(bias_addition_split=True) + meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + conv_mod_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + handler = ConvModuleHandler(node=conv_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + # assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping['other'].name == "weight" + # assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3]) + + if bias: + assert mapping['bias'].name == "bias" + # assert mapping['bias'].data.is_meta + assert mapping['bias'].data.shape == torch.Size([16]) + assert mapping['bias'].type == OperationDataType.PARAM + assert mapping['bias'].logical_shape == torch.Size([16]) + + assert mapping['output'].name == "_0" + # assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # SS = SR x RS + assert 'S0S1 = S0R x RS1' in strategy_name_list + assert 'S1S0 = S1R x RS0' in strategy_name_list + + # SR = SR x RR + assert 'S0R = S0R x RR' in strategy_name_list + assert 'S1R = S1R x RR' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R' in strategy_name_list + assert 'S1R = S1S0 x S0R' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') + output_sharding_spec = strategy.get_sharding_spec_by_name('_0') + + if bias: + bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + + # make sure the sharding matches across different operation data + assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0] + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] + assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:] + assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1] + + if bias: + assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0] + assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1] + + +class ConvModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, others, bias=None): + x = nn.functional.conv2d(input, others, bias=bias, padding=1) + return x + + +def check_conv_function_handler(rank, world_size, port, bias): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = ConvModel().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 4, 64, 64).cuda() + others = torch.rand(16, 4, 3, 3).cuda() + input_args = [input, others] + meta_arg_names = ['input', 'others'] + input_kwargs = {} + # total number of conv strategies + strategy_number = 16 + node_index = 2 + if bias: + bias_tensor = torch.rand(16).cuda() + input_kwargs['bias'] = bias_tensor + node_index += 1 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + input_kwargs=input_kwargs) + + tracer = ColoTracer(bias_addition_split=True) + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %others : torch.Tensor [#users=1] = placeholder[target=others] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {}) + # return conv2d + meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta'), "others": torch.rand(16, 4, 3, 3).to('meta')} + if bias: + meta_args['bias'] = torch.rand(16).to('meta') + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + if bias: + conv_mod_node = list(graph.nodes)[3] + else: + conv_mod_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + handler = ConvFunctionHandler(node=conv_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping['other'].name == "others" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3]) + + if bias: + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.is_meta + assert mapping['bias'].data.shape == torch.Size([16]) + assert mapping['bias'].type == OperationDataType.ARG + assert mapping['bias'].logical_shape == torch.Size([16]) + + assert mapping['output'].name == "conv2d" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # SS = SR x RS + assert 'S0S1 = S0R x RS1' in strategy_name_list + assert 'S1S0 = S1R x RS0' in strategy_name_list + + # SR = SR x RR + assert 'S0R = S0R x RR' in strategy_name_list + assert 'S1R = S1R x RR' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R' in strategy_name_list + assert 'S1R = S1S0 x S0R' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('others') + output_sharding_spec = strategy.get_sharding_spec_by_name('conv2d') + + if bias: + bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + + # make sure the sharding matches across different operation data + assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0] + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] + assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:] + assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1] + + if bias: + assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0] + assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +# We temporarily ban the bias option before doing bias add +# before all reduce communication may encounter correctness issue. +# @parameterize('bias', [True, False]) +@rerun_if_address_is_in_use() +def test_conv_module_handler(bias=False): + spawn(check_conv_module_handler, 4, bias=bias) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +# We temporarily ban the bias option before doing bias add +# before all reduce communication may encounter correctness issue. +# @parameterize('bias', [True, False]) +@rerun_if_address_is_in_use() +def test_conv_function_handler(bias=False): + spawn(check_conv_function_handler, 4, bias=bias) + + +if __name__ == '__main__': + test_conv_module_handler() + test_conv_function_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..64f56ba98e2b9a2725f2a1bd1a8b3a9a18c4d0c2 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import clear_cache_before_run, run_on_environment_flag + + +class ReshapeModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other) + reshape_node = conv_node.view(2, -1) + return reshape_node + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() +def test_reshape_handler(): + model = ReshapeModel() + tracer = ColoTracer(bias_addition_split=True) + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) + # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) + # return view + meta_args = { + "input": torch.rand(4, 4, 64, 64).to('meta'), + "other": torch.rand(16, 4, 3, 3).to('meta'), + } + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + conv_mod_node = list(graph.nodes)[2] + reshape_node = list(graph.nodes)[3] + reshape_strategies_vector = StrategiesVector(reshape_node) + conv_strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + conv_handler = ConvFunctionHandler(node=conv_mod_node, + device_mesh=device_mesh, + strategies_vector=conv_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) + reshape_handler = DefaultReshapeHandler(node=reshape_node, + device_mesh=device_mesh, + strategies_vector=reshape_strategies_vector) + + reshape_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = reshape_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "conv2d" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62]) + + assert mapping['output'].name == "view" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([2, 123008]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(reshape_strategies_vector) == len(conv_strategies_vector) + + +if __name__ == '__main__': + test_reshape_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..4fa0313b1cb5da80c53ace4be0cae5012ea68f0e --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py @@ -0,0 +1,283 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.embedding_handler import ( + EmbeddingFunctionHandler, + EmbeddingModuleHandler, +) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + +NUM_EMBEDDINGS = 16 +EMBEDDING_DIMS = 32 + + +class EmbeddingModule(nn.Module): + + def __init__(self, num_embeddings, embedding_dims): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dims) + + def forward(self, input): + x = self.embedding(input) + return x + + +def check_embedding_module_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = EmbeddingModule(num_embeddings=NUM_EMBEDDINGS, embedding_dims=EMBEDDING_DIMS).cuda() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %embedding : [#users=1] = call_module[target=embedding](args = (%input_1,), kwargs = {}) + # return embedding + input = torch.rand(4, 16, 16) * NUM_EMBEDDINGS + input = input.to(torch.int64).cuda() + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of embedding node in computation graph + node_index = 1 + # total number of embedding strategies + strategy_number = 19 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + tracer = ColoTracer(bias_addition_split=True) + meta_args = {"input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + embedding_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(embedding_node) + + # build handler + handler = EmbeddingModuleHandler(node=embedding_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + # assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 16, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([1024]) + + assert mapping['other'].name == "weight" + assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + + assert mapping['output'].name == "embedding" + assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # RR = RR x RR + assert 'RR = R x RR' in strategy_name_list + + # SR = SR x RR + assert 'S0R = S0 x RR_0' in strategy_name_list + assert 'S0R = S0 x RR_1' in strategy_name_list + assert 'S0R = S0 x RR_2' in strategy_name_list + assert 'S1R = S1 x RR_0' in strategy_name_list + assert 'S1R = S1 x RR_1' in strategy_name_list + assert 'S1R = S1 x RR_2' in strategy_name_list + + # SS = SR x RS + assert 'S0S1 = S0 x RS1_0' in strategy_name_list + assert 'S0S1 = S0 x RS1_1' in strategy_name_list + assert 'S0S1 = S0 x RS1_2' in strategy_name_list + assert 'S1S0 = S1 x RS0_0' in strategy_name_list + assert 'S1S0 = S1 x RS0_1' in strategy_name_list + assert 'S1S0 = S1 x RS0_2' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = R x RS0' in strategy_name_list + assert 'RS1 = R x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01 x RR_0' in strategy_name_list + assert 'S01R = S01 x RR_1' in strategy_name_list + assert 'S01R = S01 x RR_2' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = R x RS01' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') + output_sharding_spec = strategy.get_sharding_spec_by_name('embedding') + + # make sure the sharding matches across different operation data + assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1] + assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence[:-1] + + +class EmbeddingFunction(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, others): + x = nn.functional.embedding(input, others) + return x + + +def check_embedding_function_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = EmbeddingFunction().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 16, 16) * NUM_EMBEDDINGS + input = input.to(torch.int64).cuda() + others = torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).cuda() + input_args = [input, others] + meta_arg_names = ['input', 'others'] + input_kwargs = {} + # total number of embedding strategies + strategy_number = 19 + node_index = 2 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + input_kwargs=input_kwargs) + tracer = ColoTracer(bias_addition_split=True) + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %others : torch.Tensor [#users=1] = placeholder[target=others] + # %embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_1, %others), kwargs = {padding_idx: None, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False}) + # return embedding + meta_args = { + "input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta'), + "others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to('meta') + } + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + embedding_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(embedding_node) + + # build handler + handler = EmbeddingFunctionHandler(node=embedding_node, + device_mesh=device_mesh, + strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 16, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([1024]) + + assert mapping['other'].name == "others" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + + assert mapping['output'].name == "embedding" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) + + handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # RR = RR x RR + assert 'RR = R x RR' in strategy_name_list + + # SR = SR x RR + assert 'S0R = S0 x RR_0' in strategy_name_list + assert 'S0R = S0 x RR_1' in strategy_name_list + assert 'S0R = S0 x RR_2' in strategy_name_list + assert 'S1R = S1 x RR_0' in strategy_name_list + assert 'S1R = S1 x RR_1' in strategy_name_list + assert 'S1R = S1 x RR_2' in strategy_name_list + + # SS = SR x RS + assert 'S0S1 = S0 x RS1_0' in strategy_name_list + assert 'S0S1 = S0 x RS1_1' in strategy_name_list + assert 'S0S1 = S0 x RS1_2' in strategy_name_list + assert 'S1S0 = S1 x RS0_0' in strategy_name_list + assert 'S1S0 = S1 x RS0_1' in strategy_name_list + assert 'S1S0 = S1 x RS0_2' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = R x RS0' in strategy_name_list + assert 'RS1 = R x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01 x RR_0' in strategy_name_list + assert 'S01R = S01 x RR_1' in strategy_name_list + assert 'S01R = S01 x RR_2' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = R x RS01' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('others') + output_sharding_spec = strategy.get_sharding_spec_by_name('embedding') + + # make sure the sharding matches across different operation data + assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1] + assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence[:-1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_embedding_module_handler(): + spawn(check_embedding_module_handler, 4) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_embedding_function_handler(): + spawn(check_embedding_function_handler, 4) + + +if __name__ == '__main__': + test_embedding_module_handler() + test_embedding_function_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..a089df743ec0156f5905988f36f37b5afdb387a5 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py @@ -0,0 +1,75 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import clear_cache_before_run + + +class GetattrModel(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(4, 16, 3, padding=1, bias=False) + + def forward(self, input): + weight = self.conv.weight + return weight + + +@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@clear_cache_before_run() +def test_getattr_handler(): + model = GetattrModel() + tracer = ColoTracer(bias_addition_split=True) + # graph(): + # %input_1 : torch.Tensor [#users=0] = placeholder[target=input] + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # return conv_weight + meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + getattr_node = list(graph.nodes)[1] + getattr_strategies_vector = StrategiesVector(getattr_node) + + # build handler + getattr_handler = GetattrHandler(node=getattr_node, + device_mesh=device_mesh, + strategies_vector=getattr_strategies_vector) + + getattr_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = getattr_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['output'].name == "conv_weight" + assert mapping['output'].data.shape == torch.Size((16, 4, 3, 3)) + assert mapping['output'].type == OperationDataType.OUTPUT + strategy_name_list = [val.name for val in getattr_handler.strategies_vector] + assert 'get_attr [S0, S1, R, R]' in strategy_name_list + assert 'get_attr [S1, S0, R, R]' in strategy_name_list + assert 'get_attr [S01, R, R, R]' in strategy_name_list + assert 'get_attr [R, S01, R, R]' in strategy_name_list + assert 'get_attr [S0, R, R, R]' in strategy_name_list + assert 'get_attr [R, S0, R, R]' in strategy_name_list + assert 'get_attr [S1, R, R, R]' in strategy_name_list + assert 'get_attr [R, S1, R, R]' in strategy_name_list + assert 'get_attr [R, R, R, R]' in strategy_name_list + + +if __name__ == '__main__': + test_getattr_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..a2e0968b18bb2dc4f2d3cbe27bc5e88bae20f69e --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py @@ -0,0 +1,191 @@ +from functools import partial + +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.default_reshape_handler import DefaultReshapeHandler +from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class GetItemFromTensorModel(nn.Module): + + def __init__(self, getitem_index): + super().__init__() + self.getitem_index = getitem_index + + def forward(self, input, other): + linear_node = nn.functional.linear(input, other, bias=None) + x = linear_node[self.getitem_index] + return x + + +def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + model = GetItemFromTensorModel(getitem_index=getitem_index) + + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + + tracer = ColoTracer(bias_addition_split=True) + meta_args = { + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), + } + graph = tracer.trace(model, meta_args=meta_args) + + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *list(meta_args.values())) + linear_mod_node = list(graph.nodes)[2] + getitem_mod_node = list(graph.nodes)[3] + getitem_strategies_vector = StrategiesVector(getitem_mod_node) + linear_strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + linear_handler = LinearFunctionHandler(node=linear_mod_node, + device_mesh=device_mesh, + strategies_vector=linear_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(linear_mod_node, 'strategies_vector', linear_strategies_vector) + getitem_handler = GetItemHandler(node=getitem_mod_node, + device_mesh=device_mesh, + strategies_vector=getitem_strategies_vector) + + getitem_handler.register_strategy(compute_resharding_cost=False) + # check operation data mapping + mapping = getitem_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(getitem_strategies_vector) == len(linear_strategies_vector) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +# @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))]) +@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))]) +def test_getitem_from_tensor_handler(getitem_index): + spawn(check_getitem_from_tensor_handler, 4) + + +class GetItemFromTupleModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input): + split_node = torch.split(input, 2, 0) + x = split_node[1] + return x + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() +def test_getitem_from_tuple_handler(): + model = GetItemFromTupleModel() + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %split : [#users=1] = call_function[target=torch.functional.split](args = (%conv2d, 2), kwargs = {dim: 0}) + # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {}) + # return getitem + meta_args = { + "input": torch.rand(4, 4, 64, 64).to('meta'), + } + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + input_node = list(graph.nodes)[0] + split_node = list(graph.nodes)[1] + getitem_node = list(graph.nodes)[2] + input_strategies_vector = StrategiesVector(input_node) + getitem_strategies_vector = StrategiesVector(getitem_node) + split_strategies_vector = StrategiesVector(split_node) + + # build handler + input_handler = PlaceholderHandler( + node=input_node, + device_mesh=device_mesh, + strategies_vector=input_strategies_vector, + placeholder_option='replicated', + ) + input_handler.register_strategy(compute_resharding_cost=False) + setattr(input_node, 'strategies_vector', input_strategies_vector) + split_handler = DefaultReshapeHandler(node=split_node, + device_mesh=device_mesh, + strategies_vector=split_strategies_vector) + split_handler.register_strategy(compute_resharding_cost=False) + setattr(split_node, 'strategies_vector', split_strategies_vector) + getitem_handler = GetItemHandler(node=getitem_node, + device_mesh=device_mesh, + strategies_vector=getitem_strategies_vector) + getitem_handler.register_strategy(compute_resharding_cost=False) + setattr(getitem_node, 'strategies_vector', getitem_strategies_vector) + + # check operation data mapping + mapping = getitem_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "split" + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == (torch.Size([2, 4, 64, 64]), torch.Size([2, 4, 64, 64])) + + assert mapping['index'].name == "index" + assert isinstance(mapping['index'].data, int) + assert mapping['index'].type == OperationDataType.ARG + + assert mapping['output'].name == "getitem" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([2, 4, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(getitem_strategies_vector) == len(split_strategies_vector) + + +if __name__ == '__main__': + test_getitem_from_tensor_handler() + test_getitem_from_tuple_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ad72c2026b9aea03b83bfdf0db41c915a433618a --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py @@ -0,0 +1,106 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +def check_ln_module_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.LayerNorm(16)).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 16).cuda() + # the index of bn node in computation graph + node_index = 1 + # the total number of ln strategies + strategy_number = 4 + # construct input args + input_args = [input] + # construct meta arg names + meta_arg_names = ['input'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + tracer = ColoTracer(bias_addition_split=True) + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) + # return _0 + meta_args = {"input": torch.rand(4, 16).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + ln_mod_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(ln_mod_node) + + # build handler + handler = LayerNormModuleHandler(node=ln_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.shape == torch.Size([4, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 16]) + + assert mapping['other'].name == "weight" + assert mapping['other'].data.shape == torch.Size([16]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([16]) + + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.shape == torch.Size([16]) + assert mapping['bias'].type == OperationDataType.PARAM + assert mapping['bias'].logical_shape == torch.Size([16]) + + assert mapping['output'].name == "_0" + assert mapping['output'].data.shape == torch.Size([4, 16]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # SR = SR x R + assert '[S0, R] = [S0, R] x [R]' in strategy_name_list + assert '[S1, R] = [S1, R] x [R]' in strategy_name_list + + # RR = RR x R + assert 'RR = RR x R' in strategy_name_list + + # S01R = S01R x R + assert '[S01, R] = [S01, R] x [R]' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_ln_module_handler(): + spawn(check_ln_module_handler, 4) + + +if __name__ == '__main__': + test_ln_module_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ec695cd8f7b9e26704f28da68862844b6c75d4a4 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -0,0 +1,327 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +def check_linear_module_handler(rank, world_size, port, bias, input_shape): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(input_shape).cuda() + # the index of linear node in computation graph + node_index = 1 + # strategy number of linear node + if input_shape == (1, 4, 4, 16): + strategy_number = 19 + else: + strategy_number = 24 + # construct input args + input_args = [input] + # construct meta arg names + meta_arg_names = ['input'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + + tracer = ColoTracer(bias_addition_split=True) + meta_args = {"input": torch.rand(input_shape).cuda()} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + linear_mod_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = LinearModuleHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.shape == torch.Size(input_shape) + assert mapping['input'].type == OperationDataType.ARG + input_logical_shape = mapping['input'].data.view(-1, 16).shape + assert mapping['input'].logical_shape == input_logical_shape + + assert mapping['other'].name == "weight" + assert mapping['other'].data.shape == torch.Size([32, 16]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([16, 32]) + + if bias: + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.shape == torch.Size([32]) + assert mapping['bias'].type == OperationDataType.PARAM + assert mapping['bias'].logical_shape == torch.Size([32]) + + assert mapping['output'].name == "_0" + output_shape = input_shape[:-1] + (32,) + assert mapping['output'].data.shape == torch.Size(output_shape) + assert mapping['output'].type == OperationDataType.OUTPUT + output_logical_shape = mapping['output'].data.view(-1, 32).shape + assert mapping['output'].logical_shape == torch.Size(output_logical_shape) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # First dimension cannot be shard if input shape is (1, 4, 4, 16) + if input_shape != (1, 4, 4, 16): + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S01R = S01R x RR_0' in strategy_name_list + + # SS = SR x RS + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') + output_sharding_spec = strategy.get_sharding_spec_by_name('_0') + + if bias: + bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] + + if bias: + assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +class LinearModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, others, bias=None): + x = nn.functional.linear(input, others, bias=bias) + return x + + +def check_linear_function_handler(rank, world_size, port, bias, input_shape): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearModel().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + input = torch.rand(input_shape).cuda() + other = torch.rand(32, 16).cuda() + # the index of linear node in computation graph + node_index = 2 + # strategy number of linear node + if input_shape == (1, 4, 4, 16): + strategy_number = 19 + else: + strategy_number = 24 + # construct input args + input_args = [input, other] + # construct meta arg names + meta_arg_names = ['input', 'others'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + + tracer = ColoTracer(bias_addition_split=True) + meta_args = {'input': torch.rand(input_shape).to('meta'), 'others': torch.rand(32, 16).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + if bias: + linear_func_node = list(graph.nodes)[3] + else: + linear_func_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(linear_func_node) + + # build handler + handler = LinearFunctionHandler(node=linear_func_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # # check operation data mapping + mapping = handler.get_operation_data_mapping() + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.shape == torch.Size(input_shape) + assert mapping['input'].type == OperationDataType.ARG + input_logical_shape = mapping['input'].data.view(-1, 16).shape + assert mapping['input'].logical_shape == torch.Size(input_logical_shape) + + assert mapping['other'].name == "others" + assert mapping['other'].data.shape == torch.Size([32, 16]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([16, 32]) + + if bias: + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.shape == torch.Size([32]) + assert mapping['bias'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([16, 32]) + + assert mapping['output'].name == "linear" + output_shape = input_shape[:-1] + (32,) + assert mapping['output'].data.shape == torch.Size(output_shape) + assert mapping['output'].type == OperationDataType.OUTPUT + output_logical_shape = mapping['output'].data.view(-1, 32).shape + assert mapping['output'].logical_shape == torch.Size(output_logical_shape) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # First dimension cannot be shard if input shape is (1, 4, 4, 16) + if input_shape != (1, 4, 4, 16): + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S01R = S01R x RR_0' in strategy_name_list + + # SS = SR x RS + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('others') + output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + + if bias: + bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] + + if bias: + assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)]) +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_handler(input_shape, bias=False): + spawn( + check_linear_module_handler, + 4, + bias=bias, + input_shape=input_shape, + ) + spawn( + check_linear_function_handler, + 4, + bias=bias, + input_shape=input_shape, + ) + + +if __name__ == '__main__': + test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..938acd3d1eeacef457cc887ef92cb8624884a7e3 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py @@ -0,0 +1,175 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import ( + MatMulHandler, + MatMulType, + _get_bmm_logical_shape, + get_matmul_type, +) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing.utils import clear_cache_before_run, parameterize + + +class MatMulModule(nn.Module): + + def forward(self, x1, x2): + return torch.matmul(x1, x2) + + +@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() +@parameterize( + 'tensor_shapes', + [ + [[8], [8]], # dot product + [[4, 8], [8]], # mat-vec product + [[4, 8], [8, 16]], # mat-mat product + [[8], [8, 16]], # mat-mat product + [[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting + [[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting + [[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting + [[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting + [[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting + ]) +def test_matmul_node_handler(tensor_shapes): + input_shape, other_shape = tensor_shapes + + # get output shape + x1 = torch.rand(*input_shape) + x2 = torch.rand(*other_shape) + output_shape = list(torch.matmul(x1, x2).shape) + + # get matmul type + matmul_type = get_matmul_type(x1.dim(), x2.dim()) + + model = MatMulModule() + + tracer = ColoTracer(bias_addition_split=True) + meta_args = {"x1": x1.to('meta'), 'x2': x2.to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + physical_mesh_id = torch.arange(0, 4) + + print(graph) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + mod_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(mod_node) + + # build handler + handler = MatMulHandler(node=mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + logical_input_shape = input_shape + logical_other_shape = other_shape + logical_output_shape = output_shape + if matmul_type == MatMulType.MM and len(input_shape) == 1: + logical_input_shape = [1] + input_shape + elif matmul_type == MatMulType.BMM: + logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape( + input_shape, other_shape, handler.transforms) + else: + logical_input_shape = input_shape + + # check input operation data + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size(input_shape) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size(logical_input_shape) + + # check other operation data + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size(other_shape) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size(logical_other_shape) + + # check output + assert mapping['output'].name == "matmul" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size(output_shape) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size(logical_output_shape) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # ensure there is no duplicate strategy + if matmul_type != MatMulType.BMM: + assert len(set(strategy_name_list)) == len(strategy_name_list), strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name('matmul') + if matmul_type == MatMulType.DOT: + # dot product will produce a scaler + # results should fulfill: + # 1. the input and other operands have the same sharding spec + # 2. the output has no sharding + assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence + assert len(output_sharding_spec.sharding_sequence) == 0 + elif matmul_type == MatMulType.MV: + # matrix-vector product should fulfill + # 1. the last dim of the input and other operands should have the same sharding + # 2. the first dim of the input and other should have the same sharding + # 3. the output should have only 1 dim + assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1] + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] + assert len(output_sharding_spec.sharding_sequence) == 1 + elif matmul_type == MatMulType.MM: + # matrix-matrix multiplication should fulfil + # 1. if input is a 2D tensor, the 1st dim of input and output should have the same sharding + # 2. the input's last dim and the first dim of the other should have the same sharding + # 3. the last dim of the output and other should have the same sharding + # 4. the input and output should have the same number of dims + if len(input_shape) == 2: + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] + assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[0] + assert output_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1] + assert len(input_sharding_spec.sharding_sequence) == len(output_sharding_spec.sharding_sequence) + elif matmul_type == MatMulType.BMM: + # bmm should fulfil + # 1. of the other tensor is not a 1d tensor, the last dim of other and output have the same sharding + # 2. if the input has more than 2 dim, the second last dim of input and output have the same sharding + # 3. if the other have more than 2 dim, the second last dim of other and the last dim of input should have the same sharding + if len(other_shape) > 1: + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + if len(input_shape) > 1: + if len(other_shape) == 1: + assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-1] + else: + assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2] + if len(other_shape) > 2: + assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1] + + +if __name__ == '__main__': + test_matmul_node_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..6bff9f9648e2aca20c5c480726bfeb7038041712 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import clear_cache_before_run, run_on_environment_flag + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() +def test_norm_pool_handler(): + model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta')) + tracer = ColoTracer(bias_addition_split=True) + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) + # return _0 + meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + conv_mod_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + handler = NormPoolingHandler(node=conv_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping['output'].name == "_0" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + assert len(strategy_name_list) == 9 + + +if __name__ == '__main__': + test_norm_pool_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..5259455d2179cae56865bba73718d8e818a39c41 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py @@ -0,0 +1,70 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import clear_cache_before_run, parameterize + + +class OutputModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + y = x * 2 + return x, y + + +@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@parameterize('output_option', ['distributed', 'replicated']) +@clear_cache_before_run() +def test_output_handler(output_option): + model = OutputModel() + tracer = ColoTracer(bias_addition_split=True) + # graph(): + # %x : torch.Tensor [#users=2] = placeholder[target=x] + # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) + # return (x, mul) + meta_args = {'x': torch.rand(4, 4, 64, 64).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + output_node = list(graph.nodes)[2] + output_strategies_vector = StrategiesVector(output_node) + + # build handler + otuput_handler = OutputHandler(node=output_node, + device_mesh=device_mesh, + strategies_vector=output_strategies_vector, + output_option=output_option) + + otuput_handler.register_strategy(compute_resharding_cost=False) + # check operation data mapping + mapping = otuput_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['output'].name == "output" + assert mapping['output'].type == OperationDataType.OUTPUT + strategy_name_list = [val.name for val in otuput_handler.strategies_vector] + if output_option == 'distributed': + assert "Distributed Output" in strategy_name_list + else: + assert "Replica Output" in strategy_name_list + + +if __name__ == '__main__': + test_output_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..f071cd120fb719e832b691a7de37d811a4235708 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py @@ -0,0 +1,339 @@ +from functools import partial + +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler import PermuteHandler, TransposeHandler +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class ConvReshapeModel(nn.Module): + + def __init__(self, reshape_dims, call_function): + super().__init__() + self.reshape_dims = reshape_dims + self.call_function = call_function + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other, bias=None) + # permute_node = torch.permute(conv_node, self.permute_dims) + if self.call_function == torch.permute: + permute_node = self.call_function(conv_node, self.reshape_dims) + else: + permute_node = self.call_function(conv_node, *self.reshape_dims) + return permute_node + + +class LinearReshapeModel(nn.Module): + + def __init__(self, reshape_dims, call_function): + super().__init__() + self.reshape_dims = reshape_dims + self.call_function = call_function + + def forward(self, input, other): + linear_node = nn.functional.linear(input, other, bias=None) + # permute_node = torch.permute(linear_node, self.tgt_shape) + if self.call_function == torch.permute: + permute_node = self.call_function(linear_node, self.reshape_dims) + else: + permute_node = self.call_function(linear_node, *self.reshape_dims) + return permute_node + + +def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + if call_function == torch.permute: + reshape_dims = reshape_dims[0] + elif call_function == torch.transpose: + reshape_dims = reshape_dims[1] + model = model_cls(reshape_dims, call_function).cuda() + + if model_cls.__name__ == 'ConvReshapeModel': + input = torch.rand(8, 8, 66, 66).to('cuda') + other = torch.rand(16, 8, 3, 3).to('cuda') + # index of conv node in computation graph + node_index = 2 + # total number of conv strategies + strategy_number = 16 + if model_cls.__name__ == 'LinearReshapeModel': + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer(bias_addition_split=True) + if model_cls.__name__ == 'ConvReshapeModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {bias: None}) + # %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {}) + # return permute + meta_args = { + 'input': torch.rand(8, 8, 66, 66).to('meta'), + 'other': torch.rand(16, 8, 3, 3).to('meta'), + } + graph = tracer.trace(model, meta_args=meta_args) + + if model_cls.__name__ == 'LinearReshapeModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) + # return permute + meta_args = { + 'input': torch.rand(8, 16, 64, 32).to('meta'), + 'other': torch.rand(64, 32).to('meta'), + } + graph = tracer.trace(model, meta_args=meta_args) + + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + previous_mod_node = list(graph.nodes)[2] + reshape_node = list(graph.nodes)[3] + view_strategies_vector = StrategiesVector(reshape_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + if model_cls.__name__ == 'ConvReshapeModel': + + conv_handler = ConvFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + if model_cls.__name__ == 'LinearReshapeModel': + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + if call_function == torch.permute: + reshape_handler = PermuteHandler(node=reshape_node, + device_mesh=device_mesh, + strategies_vector=view_strategies_vector) + else: + reshape_handler = TransposeHandler(node=reshape_node, + device_mesh=device_mesh, + strategies_vector=view_strategies_vector) + + reshape_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = reshape_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + if model_cls.__name__ == 'ConvReshapeModel': + assert mapping['input'].name == "conv2d" + else: + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + if call_function == torch.permute: + assert mapping['output'].name == "permute" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.permute(torch.rand(8, 16, 64, 64), reshape_dims).shape + assert mapping['output'].type == OperationDataType.OUTPUT + else: + assert mapping['output'].name == "transpose" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.transpose(torch.rand(8, 16, 64, 64), *reshape_dims).shape + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(view_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in view_strategies_vector] + if rank == 0: + for name in strategy_name_list: + print(name) + if model_cls.__name__ == 'ConvReshapeModel': + + if reshape_dims in ((0, 2, 1, 3), (1, 2)): + assert '[S0, S1, R, R] -> [S0, R, S1, R]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [S1, R, S0, R]_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list + + if reshape_dims == (2, 0, 1, 3): + assert '[S0, S1, R, R] -> [R, S0, S1, R]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [R, S1, S0, R]_1' in strategy_name_list + assert '[S0, R, R, R] -> [R, S0, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [R, S1, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [R, S0, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [R, S1, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [R, S01, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list + + if reshape_dims == (1, 3): + assert '[S0, S1, R, R] -> [S0, R, R, S1]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [S1, R, R, S0]_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, R, S1]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, R, S0]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, R, S1]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, R, S01]_15' in strategy_name_list + + if model_cls.__name__ == 'LinearReshapeModel': + + if reshape_dims == ((0, 2, 1, 3), (1, 2)): + assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S0, R, S1]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S1, R, S0]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, S0, R, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, S1, R, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, S01, R, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + + if reshape_dims == (2, 0, 1, 3): + assert '[S0, R, R, S1] -> [R, S0, R, S1]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [S0, R, R, S1]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [R, S1, R, S0]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [S1, R, R, S0]_16' in strategy_name_list + assert '[S0, R, R, R] -> [R, S0, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [S0, R, R, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [R, S1, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [S1, R, R, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list + assert '[S01, R, R, R] -> [R, S01, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [S01, R, R, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + + if reshape_dims == (1, 3): + assert '[S0, R, R, S1] -> [S0, S1, R, R]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S1, R, S0]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S1, S0, R]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, S0, R, R]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S0, R, S1]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S0, S1, R]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, R, S0]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, R, S1]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1, R, R]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0, R, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0, R, R]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1, R, R]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, R, S01]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, S01, R, R]_4' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('call_function', [torch.permute, torch.transpose]) +@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))]) +@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel]) +def test_view_handler(call_function, reshape_dims, model_cls): + spawn( + check_view_handler, + 4, + call_function=call_function, + reshape_dims=reshape_dims, + model_cls=model_cls, + ) + + +if __name__ == '__main__': + test_view_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..6d02b0e0ba7407147fc603cbec9910ff79e58743 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py @@ -0,0 +1,80 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import clear_cache_before_run, parameterize + + +class PlaceholderModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input): + return input + + +@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@parameterize('placeholder_option', ['distributed', 'replicated']) +@clear_cache_before_run() +def test_placeholder_handler(placeholder_option): + model = PlaceholderModel() + tracer = ColoTracer(bias_addition_split=True) + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # return input_1 + meta_args = { + "input": torch.rand(4, 4, 64, 64).to('meta'), + } + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + placeholder_node = list(graph.nodes)[0] + placeholder_strategies_vector = StrategiesVector(placeholder_node) + # build handler + placeholder_handler = PlaceholderHandler(node=placeholder_node, + device_mesh=device_mesh, + strategies_vector=placeholder_strategies_vector, + placeholder_option=placeholder_option) + + placeholder_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = placeholder_handler.get_operation_data_mapping() + + strategy = placeholder_strategies_vector[0] + strategy_sharding_spec = strategy.get_sharding_spec_by_name(mapping['output'].name) + + if placeholder_option == 'distributed': + assert str(strategy_sharding_spec.sharding_sequence) == '[S01, R, R, R]' + else: + assert str(strategy_sharding_spec.sharding_sequence) == '[R, R, R, R]' + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['output'].name == "input_1" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size((4, 4, 64, 64)) + assert mapping['output'].type == OperationDataType.OUTPUT + strategy_name_list = [val.name for val in placeholder_handler.strategies_vector] + if placeholder_option == 'replicated': + assert "Replica Placeholder" in strategy_name_list + else: + assert "Distributed Placeholder" in strategy_name_list + + +if __name__ == '__main__': + test_placeholder_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py new file mode 100644 index 0000000000000000000000000000000000000000..14c364c45fc437c08a69948eea63802184fc04e9 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.options import ShardOption +from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import clear_cache_before_run, run_on_environment_flag + + +class LinearModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, others, bias=None): + x = nn.functional.linear(input, others, bias=bias) + return x + + +def check_shard_option(shard_option): + model = LinearModel().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + + tracer = ColoTracer(bias_addition_split=True) + meta_args = {'input': torch.rand(4, 4, 4, 16).to('meta'), 'others': torch.rand(32, 16).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + linear_func_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(linear_func_node) + + # build handler + handler = LinearFunctionHandler(node=linear_func_node, + device_mesh=device_mesh, + strategies_vector=strategies_vector, + shard_option=shard_option) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + if shard_option == ShardOption.SHARD_LAST_AXIS: + # RR = RS x SR + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS1 = RR x RS1' in strategy_name_list + + return + + # SS = SR x RS + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR_0' in strategy_name_list + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + if shard_option == ShardOption.SHARD: + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + if shard_option == ShardOption.STANDARD: + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() +def test_shard_option(): + # for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]: + for shard_option in [ShardOption.SHARD_LAST_AXIS]: + check_shard_option(shard_option) + + +if __name__ == '__main__': + test_shard_option() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..75ae0416ef9876795a7966c8f4f1d685583058ce --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py @@ -0,0 +1,178 @@ +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.softmax_handler import SoftmaxHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class LinearSplitModel(nn.Module): + + def __init__(self, softmax_dim): + super().__init__() + self.softmax_dim = softmax_dim + + def forward(self, input, other): + linear_node = F.linear(input, other, bias=None) + softmax_node = F.softmax(linear_node, self.softmax_dim) + return softmax_node + + +def check_split_handler(rank, world_size, port, softmax_dim, model_cls): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = model_cls(softmax_dim=softmax_dim).cuda() + + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer(bias_addition_split=True) + + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %softmax : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) + # return split + meta_args = { + 'input': torch.rand(8, 16, 64, 32).to('meta'), + 'other': torch.rand(64, 32).to('meta'), + } + graph = tracer.trace(model, meta_args=meta_args) + + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + previous_mod_node = list(graph.nodes)[2] + split_node = list(graph.nodes)[3] + split_strategies_vector = StrategiesVector(split_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + softmax_handler = SoftmaxHandler(node=split_node, + device_mesh=device_mesh, + strategies_vector=split_strategies_vector) + + softmax_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = softmax_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + assert mapping['softmax_dim'].name == "softmax_dim" + assert mapping['softmax_dim'].data == softmax_dim + assert mapping['softmax_dim'].type == OperationDataType.ARG + + assert mapping['output'].name == "softmax" + assert mapping['output'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['output'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(split_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in split_strategies_vector] + + if softmax_dim == 0: + assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + + if softmax_dim == 1: + assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('softmax_dim', [0, 1, 2, 3]) +@parameterize('model_cls', [LinearSplitModel]) +def test_split_handler(softmax_dim, model_cls): + spawn(check_split_handler, 4, softmax_dim=softmax_dim, model_cls=model_cls) + + +if __name__ == '__main__': + test_split_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..f860c629b0a0ca8fb4211e1f7b4efd4d5ccec71c --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py @@ -0,0 +1,260 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler import SplitHandler +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class ConvSplitModel(nn.Module): + + def __init__(self, split_size, split_dim): + super().__init__() + self.split_size = split_size + self.split_dim = split_dim + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other, bias=None) + split_node = conv_node.split(self.split_size, dim=self.split_dim) + return split_node + + +class LinearSplitModel(nn.Module): + + def __init__(self, split_size, split_dim): + super().__init__() + self.split_size = split_size + self.split_dim = split_dim + + def forward(self, input, other): + linear_node = nn.functional.linear(input, other, bias=None) + split_node = linear_node.split(self.split_size, dim=self.split_dim) + return split_node + + +def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = model_cls(split_size=split_size, split_dim=split_dim).cuda() + + if model_cls.__name__ == 'ConvSplitModel': + input = torch.rand(8, 8, 66, 66).to('cuda') + other = torch.rand(16, 8, 3, 3).to('cuda') + # index of conv node in computation graph + node_index = 2 + # total number of conv strategies + strategy_number = 16 + if model_cls.__name__ == 'LinearSplitModel': + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer(bias_addition_split=True) + if model_cls.__name__ == 'ConvSplitModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) + # %split : [#users=1] = call_method[target=split](args = (%conv2d,), kwargs = {}) + # return split + meta_args = { + 'input': torch.rand(8, 8, 66, 66).to('meta'), + 'other': torch.rand(16, 8, 3, 3).to('meta'), + } + graph = tracer.trace(model, meta_args=meta_args) + + if model_cls.__name__ == 'LinearSplitModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %split : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) + # return split + meta_args = { + 'input': torch.rand(8, 16, 64, 32).to('meta'), + 'other': torch.rand(64, 32).to('meta'), + } + graph = tracer.trace(model, meta_args=meta_args) + + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + previous_mod_node = list(graph.nodes)[2] + split_node = list(graph.nodes)[3] + split_strategies_vector = StrategiesVector(split_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + if model_cls.__name__ == 'ConvSplitModel': + + conv_handler = ConvFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + if model_cls.__name__ == 'LinearSplitModel': + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + split_handler = SplitHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector) + + split_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = split_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + if model_cls.__name__ == 'ConvSplitModel': + assert mapping['input'].name == "conv2d" + else: + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + assert mapping['output'].name == "split" + split_items = torch.empty([8, 16, 64, 64]).split(split_size, split_dim) + assert mapping['output'].logical_shape == tuple([item.shape for item in split_items]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(split_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in split_strategies_vector] + + if model_cls.__name__ == 'ConvSplitModel': + + if split_dim == 0: + assert '[R, S1, R, R]_0' in strategy_name_list + assert '[R, S0, R, R]_1' in strategy_name_list + assert '[R, R, R, R]_2' in strategy_name_list + assert '[R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, R]_4' in strategy_name_list + assert '[R, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R]_6' in strategy_name_list + assert '[R, S0, R, R]_7' in strategy_name_list + assert '[R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R]_10' in strategy_name_list + assert '[R, S1, R, R]_11' in strategy_name_list + assert '[R, R, R, R]_12' in strategy_name_list + assert '[R, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R]_15' in strategy_name_list + + if split_dim == 1: + assert '[S0, R, R, R]_0' in strategy_name_list + assert '[S1, R, R, R]_1' in strategy_name_list + assert '[S0, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R]_5' in strategy_name_list + assert '[R, R, R, R]_6' in strategy_name_list + assert '[R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R]_9' in strategy_name_list + assert '[R, R, R, R]_10' in strategy_name_list + assert '[R, R, R, R]_11' in strategy_name_list + assert '[R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R]_15' in strategy_name_list + + if model_cls.__name__ == 'LinearSplitModel': + + if split_dim == 0: + assert '[R, R, R, S1]_11' in strategy_name_list + assert '[R, S0, R, S1]_12' in strategy_name_list + assert '[R, R, S0, S1]_13' in strategy_name_list + assert '[R, R, R, S0]_14' in strategy_name_list + assert '[R, S1, R, S0]_15' in strategy_name_list + assert '[R, R, S1, S0]_16' in strategy_name_list + assert '[R, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R]_18' in strategy_name_list + assert '[R, R, S0, R]_19' in strategy_name_list + assert '[R, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R]_21' in strategy_name_list + assert '[R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1]_5' in strategy_name_list + assert '[R, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R]_1' in strategy_name_list + assert '[R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01]_4' in strategy_name_list + + if split_dim == 1: + assert '[S0, R, R, S1]_11' in strategy_name_list + assert '[R, R, R, S1]_12' in strategy_name_list + assert '[R, R, S0, S1]_13' in strategy_name_list + assert '[S1, R, R, S0]_14' in strategy_name_list + assert '[R, R, R, S0]_15' in strategy_name_list + assert '[R, R, S1, S0]_16' in strategy_name_list + assert '[S0, R, R, R]_17' in strategy_name_list + assert '[R, R, R, R]_18' in strategy_name_list + assert '[R, R, S0, R]_19' in strategy_name_list + assert '[S1, R, R, R]_20' in strategy_name_list + assert '[R, R, R, R]_21' in strategy_name_list + assert '[R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1]_5' in strategy_name_list + assert '[S01, R, R, R]_0' in strategy_name_list + assert '[R, R, R, R]_1' in strategy_name_list + assert '[R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01]_4' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('split_size', [2]) +@parameterize('split_dim', [0, 1, 2]) +@parameterize('model_cls', [ConvSplitModel, LinearSplitModel]) +def test_split_handler(split_size, split_dim, model_cls): + spawn(check_split_handler, 4, split_size=split_size, split_dim=split_dim, model_cls=model_cls) + + +if __name__ == '__main__': + test_split_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..c11291ecac969f4f2a93b12646baa546ad231f13 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py @@ -0,0 +1,230 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.sum_handler import SumHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class LinearSumModel(nn.Module): + + def __init__(self, sum_dims, keepdim): + super().__init__() + self.sum_dims = sum_dims + self.keepdim = keepdim + + def forward(self, input, other): + linear_node = nn.functional.linear(input, other, bias=None) + if self.sum_dims is not None: + sum_node = torch.sum(linear_node, self.sum_dims, keepdim=self.keepdim) + else: + sum_node = torch.sum(linear_node, keepdim=self.keepdim) + return sum_node + + +def check_sum_handler(rank, world_size, port, sum_dims, keepdim): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 24 + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + + tracer = ColoTracer(bias_addition_split=True) + + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%linear,), kwargs = {}) + # return sum_1 + meta_args = { + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), + } + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + previous_mod_node = list(graph.nodes)[2] + sum_node = list(graph.nodes)[3] + sum_strategies_vector = StrategiesVector(sum_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + sum_handler = SumHandler(node=sum_node, device_mesh=device_mesh, strategies_vector=sum_strategies_vector) + + sum_handler.register_strategy(compute_resharding_cost=False) + + # sum handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(sum_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in sum_strategies_vector] + + # check operation data mapping + mapping = sum_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + assert mapping['output'].name == "sum_1" + sum_node_shape = torch.empty([8, 16, 64, 64]).sum(sum_dims, keepdim=keepdim).shape + assert mapping['output'].logical_shape == sum_node_shape + assert mapping['output'].type == OperationDataType.OUTPUT + + # check strategy name + if sum_dims == (0, 2) and keepdim == False: + assert '[R, R, R, R] -> [R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [S01, R]_1' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, S01]_4' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1]_5' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0]_6' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_8' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0]_9' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1]_10' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [S0, S1]_12' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1]_13' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [S1, S0]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0]_16' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [S0, R]_18' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_19' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [S1, R]_21' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_22' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_23' in strategy_name_list + + if sum_dims == (0, 2) and keepdim == True: + assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_13' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_22' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list + + if sum_dims == 1 and keepdim == False: + assert '[S01, R, R, R] -> [S01, R, R]_0' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, S01]_4' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, S1]_5' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, S0]_6' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_8' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, S0]_9' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, S1]_10' in strategy_name_list + assert '[S0, R, R, S1] -> [S0, R, S1]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, S1]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S0, S1]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, S0]_14' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, S0]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S1, S0]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R]_17' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, S0, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, S1, R]_22' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_23' in strategy_name_list + + if sum_dims == 1 and keepdim == True: + assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list + assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('sum_dims', [(0, 2), 1]) +@parameterize('keepdim', [False, True]) +def test_sum_handler(sum_dims, keepdim): + spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim) + + +if __name__ == '__main__': + test_sum_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..5b6ac051a8ef7944f35bb3e94a1170e3bfd8efc2 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py @@ -0,0 +1,71 @@ +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import clear_cache_before_run, run_on_environment_flag + + +class TensorConstructorModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + arange_node = torch.arange(x.size()[0]) + x = x + arange_node + return x + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() +def test_where_handler(): + model = TensorConstructorModel() + tracer = ColoTracer(bias_addition_split=True) + # graph(): + # %x : torch.Tensor [#users=2] = placeholder[target=x] + # %size : [#users=1] = call_method[target=size](args = (%x,), kwargs = {}) + # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%size, 0), kwargs = {}) + # %arange : [#users=1] = call_function[target=torch.arange](args = (%getitem,), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%x, %arange), kwargs = {}) + # return add + meta_args = {'x': torch.rand(10).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + arange_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(arange_node) + + # build handler + handler = TensorConstructorHandler(node=arange_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['output'].name == "arange" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([10]) + assert mapping['output'].type == OperationDataType.OUTPUT + + handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + assert 'Replica Tensor Constructor' in strategy_name_list + + +if __name__ == '__main__': + test_where_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e6dafdfd692e5c1ef5a5862c70f61ac31d6994 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py @@ -0,0 +1,89 @@ +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import clear_cache_before_run, run_on_environment_flag + + +class ReLuModel(nn.Module): + + def __init__(self): + super().__init__() + self.act = torch.nn.ReLU() + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other) + relu_node = self.act(conv_node) + return relu_node + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() +def test_elementwise_handler(): + model = ReLuModel() + tracer = ColoTracer(bias_addition_split=True) + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) + # %act : [#users=1] = call_module[target=act](args = (%conv2d,), kwargs = {}) + # return act + meta_args = { + 'input': torch.rand(4, 4, 64, 64).to('meta'), + 'other': torch.rand(16, 4, 3, 3).to('meta'), + } + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + conv_mod_node = list(graph.nodes)[2] + relu_mod_node = list(graph.nodes)[3] + relu_strategies_vector = StrategiesVector(relu_mod_node) + conv_strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + conv_handler = ConvFunctionHandler(node=conv_mod_node, + device_mesh=device_mesh, + strategies_vector=conv_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) + relu_handler = UnaryElementwiseHandler(node=relu_mod_node, + device_mesh=device_mesh, + strategies_vector=relu_strategies_vector) + + relu_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = relu_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "conv2d" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62]) + + assert mapping['output'].name == "act" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 16, 62, 62]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(relu_strategies_vector) == len(conv_strategies_vector) + + +if __name__ == '__main__': + test_elementwise_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..fbb194d8e0b8d119cec8eb27f09657ee95dbeaaf --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py @@ -0,0 +1,258 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler import ViewHandler +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class ConvViewModel(nn.Module): + + def __init__(self, tgt_shape): + super().__init__() + self.tgt_shape = tgt_shape + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other, bias=None) + reshape_node = conv_node.view(*self.tgt_shape) + return reshape_node + + +class LinearViewModel(nn.Module): + + def __init__(self, tgt_shape): + super().__init__() + self.tgt_shape = tgt_shape + + def forward(self, input, other): + linear_node = nn.functional.linear(input, other, bias=None) + reshape_node = linear_node.view(*self.tgt_shape) + return reshape_node + + +def check_view_handler(rank, tgt_shape, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = model_cls(tgt_shape).cuda() + + if model_cls.__name__ == 'ConvViewModel': + input = torch.rand(8, 8, 66, 66).to('cuda') + other = torch.rand(16, 8, 3, 3).to('cuda') + # index of conv node in computation graph + node_index = 2 + # total number of conv strategies + strategy_number = 16 + if model_cls.__name__ == 'LinearViewModel': + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer(bias_addition_split=True) + if model_cls.__name__ == 'ConvViewModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) + # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) + # return view + meta_args = {'input': torch.rand(8, 8, 66, 66).to('meta'), 'other': torch.rand(16, 8, 3, 3).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + + if model_cls.__name__ == 'LinearViewModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %view : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) + # return view + meta_args = { + 'input': torch.rand(8, 16, 64, 32).to('meta'), + 'other': torch.rand(64, 32).to('meta'), + } + graph = tracer.trace(model, meta_args=meta_args) + + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + + previous_mod_node = list(graph.nodes)[2] + view_node = list(graph.nodes)[3] + view_strategies_vector = StrategiesVector(view_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + if model_cls.__name__ == 'ConvViewModel': + + conv_handler = ConvFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + if model_cls.__name__ == 'LinearViewModel': + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + view_handler = ViewHandler(node=view_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector) + + view_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = view_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + if model_cls.__name__ == 'ConvViewModel': + assert mapping['input'].name == "conv2d" + else: + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + assert mapping['output'].name == "view" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size(tgt_shape) + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(view_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in view_strategies_vector] + + if model_cls.__name__ == 'ConvViewModel': + + if tgt_shape == (32, 4, 64, 16, 4): + assert '[S0, S1, R, R] -> FULLY REPLICATED_0' in strategy_name_list + assert '[S1, S0, R, R] -> FULLY REPLICATED_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> FULLY REPLICATED_6' in strategy_name_list + assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> FULLY REPLICATED_10' in strategy_name_list + assert '[R, S1, R, R] -> FULLY REPLICATED_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> FULLY REPLICATED_15' in strategy_name_list + + if tgt_shape == (8, 4, 4, 64, 16, 4): + assert '[S0, S1, R, R] -> [S0, S1, R, R, R, R]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [S1, S0, R, R, R, R]_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_15' in strategy_name_list + + if model_cls.__name__ == 'LinearViewModel': + + if tgt_shape == (32, 4, 64, 16, 4): + for strategy in strategy_name_list: + print(strategy) + # print(strategy_name_list) + assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_11' in strategy_name_list + assert '[R, S0, R, S1] -> FULLY REPLICATED_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_14' in strategy_name_list + assert '[R, S1, R, S0] -> FULLY REPLICATED_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> FULLY REPLICATED_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> FULLY REPLICATED_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1, R]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0, R]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1, R]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> FULLY REPLICATED_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01, R]_4' in strategy_name_list + + if tgt_shape == (8, 4, 4, 64, 16, 4): + assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_4' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)]) +@parameterize('model_cls', [ConvViewModel, LinearViewModel]) +def test_view_handler(tgt_shape, model_cls): + spawn(check_view_handler, 4, tgt_shape=tgt_shape, model_cls=model_cls) + + +if __name__ == '__main__': + test_view_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..bd7635ac1737424d1f78c2155f2d3308d7d343b0 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py @@ -0,0 +1,92 @@ +import pytest +import torch +import torch.nn as nn + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import clear_cache_before_run + + +class ConvModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, condition, x, y): + output = torch.where(condition, x, y) + return output + + +@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@clear_cache_before_run() +def test_where_handler(): + model = ConvModel() + tracer = ColoTracer(bias_addition_split=True) + # graph(): + # %condition : torch.Tensor [#users=1] = placeholder[target=condition] + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %y : torch.Tensor [#users=1] = placeholder[target=y] + # %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {}) + # return where + meta_args = { + 'condition': torch.rand(4, 4, 64, 64).to('meta'), + 'x': torch.rand(4, 1, 64, 64).to('meta'), + 'y': torch.rand(1, 4, 64, 64).to('meta') + } + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + shape_prop_pass(gm, *meta_args.values()) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + where_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(where_node) + + # build handler + handler = WhereHandler(node=where_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping, _ = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['condition'].name == "condition" + assert mapping['condition'].data.is_meta + assert mapping['condition'].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping['condition'].type == OperationDataType.ARG + assert mapping['condition'].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping['x'].name == "x" + assert mapping['x'].data.is_meta + assert mapping['x'].data.shape == torch.Size([4, 1, 64, 64]) + assert mapping['x'].type == OperationDataType.ARG + assert mapping['x'].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping['y'].name == "y" + assert mapping['y'].data.is_meta + assert mapping['y'].data.shape == torch.Size([1, 4, 64, 64]) + assert mapping['y'].type == OperationDataType.ARG + assert mapping['y'].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping['output'].name == "where" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + # 4*3 + 4*3/2*2 + 1 + assert len(strategy_name_list) == 25 + + +if __name__ == '__main__': + test_where_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..28a8bbd9a4c11df2a2fbf324f6873c6d2c75ae4f --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py @@ -0,0 +1,192 @@ +import copy +from typing import Dict, List + +import torch +from torch.fx import GraphModule + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph +from colossalai.auto_parallel.tensor_shard.solver.solver import Solver +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import to_global +from colossalai.testing.comparison import assert_close + + +def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tensor], + input_kwargs: Dict[str, torch.Tensor], grad_dict: Dict[any, torch.Tensor]): + + model_to_compare = copy.deepcopy(model) + args_to_compare = [] + kwargs_to_compare = {} + for arg_index, input_tensor in enumerate(input_args): + + def wrapper(param, index): + + def hook_fn(grad): + grad_dict[index] = grad + + param.register_hook(hook_fn) + + arg_to_compare = copy.deepcopy(input_tensor) + + # only Tensors of floating point and complex dtype can require gradients + if arg_to_compare.dtype != torch.int64: + arg_to_compare.requires_grad = True + wrapper(arg_to_compare, arg_index) + + args_to_compare.append(arg_to_compare) + + for name, input_kwarg in input_kwargs.items(): + + def wrapper(param, name): + + def hook_fn(grad): + grad_dict[name] = grad + + param.register_hook(hook_fn) + + kwarg_to_compare = copy.deepcopy(input_kwarg) + + # only Tensors of floating point and complex dtype can require gradients + if kwarg_to_compare.dtype != torch.int64: + kwarg_to_compare.requires_grad = True + wrapper(kwarg_to_compare, name) + + kwargs_to_compare[name] = kwarg_to_compare + + return model_to_compare, args_to_compare, kwargs_to_compare + + +def numerical_test_for_node_strategy(model: torch.nn.Module, + device_mesh: DeviceMesh, + node_index: int, + strategy_number: int, + input_args: List[torch.Tensor], + meta_arg_names: List[str], + input_kwargs: Dict[str, torch.Tensor] = {}, + node_type: str = 'normal'): + for strategy_index in range(strategy_number): + print(f'#strategy_index: {strategy_index}') + # We need to copy the model to avoid do backward more than once in same graph + grad_to_compare_dict = {} + grad_to_shard_dict = {} + model_to_compare, args_to_compare, kwargs_to_compare = _build_model_to_compare( + model, input_args, input_kwargs, grad_to_compare_dict) + model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs, + grad_to_shard_dict) + + tracer = ColoTracer(bias_addition_split=True) + input_sample = {} + for input_arg, meta_arg_name in zip(input_args, meta_arg_names): + input_sample[meta_arg_name] = torch.empty(input_arg.shape, dtype=input_arg.dtype).to('meta') + for meta_kwarg_name, input_kwarg in input_kwargs.items(): + input_sample[meta_kwarg_name] = torch.empty(input_kwarg.shape, dtype=input_kwarg.dtype).to('meta') + graph = tracer.trace(root=model_to_shard, meta_args=input_sample) + gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) + shape_prop_pass(gm, *input_sample.values()) + + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies + ][node_index] + if node_type == 'normal': + solution_len = len(strategies_constructor.leaf_strategies) + solution = [0] * solution_len + solution[node_index] = strategy_index + elif node_type == 'following': + solution_len = len(strategies_constructor.leaf_strategies) + solution = [0] * solution_len + solution[node_index] = strategy_index + solution[node_index + 1] = strategy_index + else: + node_vector = strategies_constructor.leaf_strategies[node_index] + strategy_to_keep = node_vector[strategy_index] + node_vector = [strategy_to_keep] + # solution construction + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + solver = Solver(gm.graph, strategies_constructor, cost_graph, verbose=False) + ret = solver.call_solver_serialized_args() + solution = list(ret[0]) + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( + gm, solution, device_mesh, strategies_constructor) + gm = runtime_apply_pass(gm) + gm.recompile() + + # forward result compare + output = gm(*args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard) + output_to_compare = model_to_compare(*args_to_compare, **kwargs_to_compare) + assert_close_helper(output, output_to_compare, strategy_index=strategy_index, type='forward output') + + # backward result compare + if isinstance(output, (tuple, list)): + loss = output[0].sum() + loss_to_compare = output_to_compare[0].sum() + else: + loss = output.sum() + loss_to_compare = output_to_compare.sum() + + loss_to_compare.backward() + loss.backward() + for key in grad_to_shard_dict.keys(): + grad_to_shard = grad_to_shard_dict[key] + grad_to_compare = grad_to_compare_dict[key] + assert_close_helper(grad_to_shard, grad_to_compare, strategy_index=strategy_index, type='input grad') + # extract the strategy used in this iter + strategy_in_use = target_node.strategies_vector[strategy_index] + param_to_shard_dict = dict(gm.named_parameters()) + param_to_compare_dict = dict(model_to_compare.named_parameters()) + for name in param_to_shard_dict.keys(): + param_name = name.split('.')[-1] + if node_type == 'normal': + param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name) + else: + if 'weight' in name: + param_sharding_spec = None + + for node in list(graph.nodes): + if 'weight' in node.name: + param_sharding_spec = node.sharding_spec + + elif 'bias' in name: + param_sharding_spec = None + + for node in list(graph.nodes): + if 'bias' in node.name: + param_sharding_spec = node.sharding_spec + + assert param_sharding_spec is not None + grad_sharded = param_to_shard_dict[name].grad + grad_to_compare = param_to_compare_dict[name].grad + global_grad = to_global(grad_sharded, param_sharding_spec) + assert_close_helper(global_grad, grad_to_compare, strategy_index=strategy_index, type='param grad') + + +def assert_close_helper(first: torch.Tensor, + second: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2, + strategy_index: int = -1, + type: str = 'not defined'): + """ + This method is used to check whether the average difference between two tensors is as close as expected. + """ + try: + if isinstance(first, (tuple, list)): + for first_element, second_element in zip(first, second): + assert_close(first_element, second_element, rtol=rtol, atol=atol) + else: + assert_close(first, second, rtol=rtol, atol=atol) + except: + print(f'strategy index {strategy_index} encounter assert_close error on {type}') diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..0d93e4e4052792158729b382f4f679996bcc48e1 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py @@ -0,0 +1,96 @@ +import torch +from torch.fx import GraphModule +from torchvision.models import resnet50 + +from colossalai._analyzer.fx.passes import shape_prop_pass +# from colossalai.fx.tracer.tracer import ColoTracer +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP +from colossalai.auto_parallel.tensor_shard.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.testing import clear_cache_before_run, run_on_environment_flag + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() +def test_cost_graph(): + physical_mesh_id = torch.arange(0, 8) + mesh_shape = (2, 4) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer(bias_addition_split=True) + model = resnet50(num_classes=100000) + input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')} + + graph = tracer.trace(root=model, meta_args=input_sample) + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {}) + # %bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {}) + # %relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {}) + # %maxpool : [#users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {}) + # %layer1_0_conv1 : [#users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {}) + # %layer1_0_bn1 : [#users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {}) + # %layer1_0_relu : [#users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {}) + # %layer1_0_conv2 : [#users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {}) + # %layer1_0_bn2 : [#users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {}) + # %layer1_0_relu_1 : [#users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {}) + # %layer1_1_conv1 : [#users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {}) + # %layer1_1_bn1 : [#users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {}) + # %layer1_1_relu : [#users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {}) + # %layer1_1_conv2 : [#users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {}) + # %layer1_1_bn2 : [#users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {}) + # %add_1 : [#users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {}) + # ... + # %avgpool : [#users=1] = call_module[target=avgpool](args = (%layer4_2_relu_1,), kwargs = {}) + # %flatten : [#users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {}) + # %fc : [#users=1] = call_module[target=fc](args = (%flatten,), kwargs = {}) + # return fc + gm = GraphModule(model, graph, model.__class__.__name__) + shape_prop_pass(gm, *input_sample.values()) + gm.recompile() + + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + solver = Solver(gm.graph, strategies_constructor, cost_graph) + + ret = solver.call_solver_serialized_args() + print(ret[0]) + print(solver.last_s_val) + strategies_list = solver.last_s_val + + computation_cost = 0 + communication_cost = 0 + communication_cost_bn = 0 + memory_cost = 0 + for index, node in enumerate(graph.nodes): + if node.op == 'call_module': + submod = node.graph.owning_module.get_submodule(node.target) + if type(submod) in BATCHNORM_MODULE_OP: + communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost.total + print(node.name, node.strategies_vector[strategies_list[index]].name) + computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total + communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total + node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total + if isinstance(node_memory_cost, tuple): + node_memory_cost = node_memory_cost[0] + memory_cost += node_memory_cost.activation + node_memory_cost.parameter + + print(f'computation cost is {computation_cost}') + print(f'communication cost is {communication_cost}') + print(f'memory cost is {memory_cost}') + print(f'bn communication cost is {communication_cost_bn}') + + +if __name__ == '__main__': + test_cost_graph() diff --git a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py new file mode 100644 index 0000000000000000000000000000000000000000..d07145e48e1f5e86a9c1dcf9e1ed5644bee513aa --- /dev/null +++ b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py @@ -0,0 +1,140 @@ +import time +from typing import Any, Dict, List + +import torch +import torch.fx + +import colossalai +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.testing import free_port + +if AUTOCHUNK_AVAILABLE: + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace + + +def _benchmark_evoformer_stack_gm( + data_args: tuple, + max_memory: int, + get_model: Any, + get_data: Any, +) -> None: + # build model and input + model = get_model().cpu().eval() + meta_args, concrete_args = get_data(*data_args) + if concrete_args is None: + concrete_args = [] + + # trace the meta graph and setup codegen + meta_graph = symbolic_trace( + model, + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, + concrete_args={k: v for k, v in concrete_args}, + ) + interp = MetaInfoProp(meta_graph) + meta_tensors = [MetaTensor(i[1], fake_device="cpu") for i in meta_args] + [i[1] for i in concrete_args] + interp.propagate(*meta_tensors) + codegen = AutoChunkCodeGen( + meta_graph, + max_memory=max_memory, + ) + + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model, + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, + concrete_args={k: v for k, v in concrete_args}, + ) + graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) + gm.recompile() + + # init inputs + inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args] + inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] + model.cuda() + + # bench + mem = _benchmark_memory(gm, inputs) + speed = _benchmark_speed(gm, inputs) + print("evoformer stack gm, mem: %.2fMB, time: %.4fs" % (mem, speed)) + + +def _benchmark_evoformer_stack_origin( + data_args: tuple, + get_model: Any, + get_data: Any, +) -> None: + # build model and input + model = get_model() + meta_args, concrete_args = get_data(*data_args) + if concrete_args is None: + concrete_args = [] + + # init inputs + inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args] + inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] + model.cuda() + + # bench + mem = _benchmark_memory(model, inputs) + speed = _benchmark_speed(model, inputs) + print("evoformer stack origin, mem: %.2fMB, time: %.4fs" % (mem, speed)) + return mem + + +def _benchmark_memory(model, inputs): + with torch.no_grad(): + torch.cuda.reset_peak_memory_stats() + now_mem = torch.cuda.memory_allocated() / 1024**2 + model(*inputs) + new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + return new_max_mem - now_mem + + +def _benchmark_speed(model, inputs, loop=5): + with torch.no_grad(): + for _ in range(loop // 2 + 1): + model(*inputs) + torch.cuda.synchronize() + time1 = time.time() + for _ in range(loop): + model(*inputs) + torch.cuda.synchronize() + time2 = time.time() + return (time2 - time1) / loop + + +def benchmark_evoformer_stack(data_args): + from test_autochunk_evoformer_stack import get_data, get_model + print("\nmsa len: %d, pair len: %d" % (data_args[0], data_args[1])) + max_mem = _benchmark_evoformer_stack_origin(data_args, get_model, get_data) + for ratio in [0.5, 0.4, 0.3, 0.2, 0.1]: + try: + _benchmark_evoformer_stack_gm(data_args, max_mem * ratio, get_model, get_data) + except RuntimeError as e: + if e.args[0] == 'Search failed. Try a larger memory threshold.': + break + except Exception as e: + raise e + _benchmark_evoformer_stack_gm(data_args, None, get_model, get_data) + + +if __name__ == "__main__": + # launch colossalai + colossalai.launch( + config={}, + rank=0, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + benchmark_evoformer_stack((256, 256)) + benchmark_evoformer_stack((256, 512)) + benchmark_evoformer_stack((256, 1024)) + benchmark_evoformer_stack((256, 1280)) diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..15610e2b50dcbdb24431b7add4cd9fae6a98619e --- /dev/null +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py @@ -0,0 +1,132 @@ +from typing import Any, Dict, List + +import torch +import torch.fx + +import colossalai +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.autochunk.utils import flat_list +from colossalai.core import global_context as gpc +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.testing import free_port + +if AUTOCHUNK_AVAILABLE: + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace + + +def assert_codegen_run( + model: Any, + meta_args: List, + concrete_args: List = None, + max_memory: int = None, + print_mem: bool = False, + print_est_mem: bool = False, + print_progress: bool = False, + print_code: bool = False, +) -> List[Dict]: + if concrete_args is None: + concrete_args = [] + + # trace the meta graph and setup codegen + meta_graph = symbolic_trace( + model, + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, + concrete_args={k: v for k, v in concrete_args}, + ) + interp = MetaInfoProp(meta_graph) + meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args] + interp.propagate(*meta_tensors) + codegen = AutoChunkCodeGen( + meta_graph, + max_memory=max_memory, + print_mem=print_est_mem, + print_progress=print_progress, + ) + chunks = codegen.chunk_infos + + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model, + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, + concrete_args={k: v for k, v in concrete_args}, + ) + graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) + gm.recompile() + + # assert chunk in code + code = graph.python_code("self").src + if print_code: + print(code) + assert "chunk_size = None; " in code + + # assert result + inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args] + inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] + model.cuda() + with torch.no_grad(): + if print_mem: + torch.cuda.reset_peak_memory_stats() + now_mem = torch.cuda.memory_allocated() / 1024**2 + out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]) + if print_mem: + new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + print("mem: %.2fMB" % (new_max_mem - now_mem)) + out_model = model(*inputs) + out_gm = flat_list(out_gm) + out_model = flat_list(out_model) + for out_gm_i, out_model_i in zip(out_gm, out_model): + assert torch.allclose(out_gm_i, out_model_i, + atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(out_gm_i - out_model_i)) + + return chunks + + +def run_test( + rank: int, + data_args: tuple, + max_memory: int, + get_model: Any, + get_data: Any, + print_code: bool = False, + print_mem: bool = False, + print_est_mem: bool = False, + print_progress: bool = False, + get_chunk_target: Any = None, +) -> None: + # launch colossalai + colossalai.launch( + config={}, + rank=rank, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + + # build model and input + model = get_model() + meta_args, concrete_args = get_data(*data_args) + chunks = assert_codegen_run( + model, + meta_args=meta_args, + concrete_args=concrete_args, + max_memory=max_memory, + print_code=print_code, + print_mem=print_mem, + print_est_mem=print_est_mem, + print_progress=print_progress, + ) + + if get_chunk_target is not None: + chunk_found = [i["region"] for i in chunks] + chunk_target = get_chunk_target()[max_memory] + assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % ( + str(chunk_found), + str(chunk_target), + ) diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..9e4cb7ee9f95b212821809e3ab25906ab82540aa --- /dev/null +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py @@ -0,0 +1,95 @@ +from typing import Dict, List, Tuple + +import pytest +import torch +import torch.fx + +try: + from fastfold.model.nn.evoformer import EvoformerBlock + HAS_REPO = True +except: + HAS_REPO = False + +from test_autochunk_alphafold_utils import run_test + +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn + + +def get_model(): + model = EvoformerBlock( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + no_heads_msa=8, + no_heads_pair=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.15, + inf=1e4, + eps=1e-4, + is_multimer=False, + ).eval().cuda() + return model + + +def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: + node = torch.randn(1, msa_len, pair_len, 256).cuda() + node_mask = torch.randn(1, msa_len, pair_len).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() + pair_mask = torch.randn(1, pair_len, pair_len).cuda() + + meta_args = [ + ("m", node), + ("z", pair), + ("msa_mask", node_mask), + ("pair_mask", pair_mask), + ] + concrete_args = [("chunk_size", None), ("_mask_trans", True)] + return meta_args, concrete_args + + +def get_chunk_target() -> Dict: + return { + None: [(120, 126), (225, 244), (270, 289), (306, 311), (70, 106), (23, 46), (146, 152), (187, 193), (181, 184), + (140, 145), (162, 163), (203, 204)], + 20: [(120, 123), (232, 237), (277, 282), (305, 306)], + 24: [(122, 123)], + } + + +@pytest.mark.skipif( + not (AUTOCHUNK_AVAILABLE and HAS_REPO), + reason="torch version is lower than 1.12.0", +) +@clear_cache_before_run() +@parameterize("max_memory", [None, 20, 24]) +@parameterize("data_args", [(32, 64)]) +def test_evoformer_block(data_args, max_memory): + spawn( + run_test, + 1, + data_args=data_args, + max_memory=max_memory, + get_model=get_model, + get_data=get_data, + get_chunk_target=get_chunk_target, + ) + + +if __name__ == "__main__": + run_test( + rank=0, + data_args=(32, 64), + max_memory=24, + get_model=get_model, + get_data=get_data, + get_chunk_target=get_chunk_target, + print_code=False, + print_mem=False, + print_est_mem=False, + print_progress=False, + ) diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py new file mode 100644 index 0000000000000000000000000000000000000000..6b47033e199f0ac47ef6df90bb4d28730eae1bd9 --- /dev/null +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py @@ -0,0 +1,87 @@ +from typing import List, Tuple + +import pytest +import torch +import torch.fx + +try: + from fastfold.model.nn.evoformer import EvoformerStack + HAS_REPO = True +except: + HAS_REPO = False + +from test_autochunk_alphafold_utils import run_test + +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn + + +def get_model(): + model = EvoformerStack( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + c_s=384, + no_heads_msa=8, + no_heads_pair=4, + no_blocks=2, # 48 + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.25, + blocks_per_ckpt=None, + inf=1000000000.0, + eps=1e-08, + clear_cache_between_blocks=False, + is_multimer=False, + ).eval().cuda() + return model + + +def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: + node = torch.randn(1, msa_len, pair_len, 256).cuda() + node_mask = torch.randn(1, msa_len, pair_len).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() + pair_mask = torch.randn(1, pair_len, pair_len).cuda() + + meta_args = [ + ("m", node), + ("z", pair), + ("msa_mask", node_mask), + ("pair_mask", pair_mask), + ] + concrete_args = [("chunk_size", None), ("_mask_trans", True)] + return meta_args, concrete_args + + +@pytest.mark.skipif( + not (AUTOCHUNK_AVAILABLE and HAS_REPO), + reason="torch version is lower than 1.12.0", +) +@clear_cache_before_run() +@parameterize("max_memory", [None, 20, 24]) +@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) +def test_evoformer_stack(data_args, max_memory): + spawn( + run_test, + 1, + data_args=data_args, + max_memory=max_memory, + get_model=get_model, + get_data=get_data, + ) + + +if __name__ == "__main__": + run_test( + rank=0, + data_args=(32, 64), + max_memory=None, + get_model=get_model, + get_data=get_data, + print_code=False, + print_mem=False, + print_progress=False, + ) diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c577c18ee602fcf49724b413c2e323931530f5 --- /dev/null +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py @@ -0,0 +1,83 @@ +from typing import Dict, List, Tuple + +import pytest +import torch +import torch.fx + +try: + from fastfold.model.nn.evoformer import ExtraMSABlock + HAS_REPO = True +except: + HAS_REPO = False +from test_autochunk_alphafold_utils import run_test + +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn + + +def get_model(): + model = ExtraMSABlock( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + no_heads_msa=8, + no_heads_pair=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.15, + inf=1e4, + eps=1e-4, + ckpt=False, + is_multimer=False, + ).eval().cuda() + return model + + +def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: + node = torch.randn(1, msa_len, pair_len, 256).cuda() + node_mask = torch.randn(1, msa_len, pair_len).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() + pair_mask = torch.randn(1, pair_len, pair_len).cuda() + + meta_args = [ + ("m", node), + ("z", pair), + ("msa_mask", node_mask), + ("pair_mask", pair_mask), + ] + concrete_args = [("chunk_size", None), ("_chunk_logits", 1024)] + return meta_args, concrete_args + + +@pytest.mark.skipif( + not (AUTOCHUNK_AVAILABLE and HAS_REPO), + reason="torch version is lower than 1.12.0", +) +@clear_cache_before_run() +@parameterize("max_memory", [None, 20, 24]) +@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) +def test_extramsa_block(data_args, max_memory): + spawn( + run_test, + 1, + data_args=data_args, + max_memory=max_memory, + get_model=get_model, + get_data=get_data, + ) + + +if __name__ == "__main__": + run_test( + rank=0, + data_args=(32, 64), + max_memory=None, + get_model=get_model, + get_data=get_data, + print_code=False, + print_mem=False, + print_progress=False, + ) diff --git a/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py b/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb7efa7a8fc986b6811fe9958d2365b3fa7a1f4 --- /dev/null +++ b/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py @@ -0,0 +1,147 @@ +import time +from typing import Any, Dict, List + +import torch +import torch.fx + +import colossalai +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.profiler import parameter_size +from colossalai.utils import free_port + +if AUTOCHUNK_AVAILABLE: + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace + + +def _benchmark_autochunk_unet_gm( + model: Any, + data: tuple, + max_memory: int = None, +) -> None: + model = model.cuda().eval() + + # build model and input + meta_args, concrete_args = data + if concrete_args is None: + concrete_args = {} + + # trace the meta graph and setup codegen + meta_graph = symbolic_trace( + model, + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, + concrete_args={k: v for k, v in concrete_args}, + ) + interp = MetaInfoProp(meta_graph) + meta_tensors = [i[1] for i in meta_args] + [i[1] for i in concrete_args] + meta_tensors = [MetaTensor(i, fake_device="cpu") if isinstance(i, torch.Tensor) else i for i in meta_tensors] + interp.propagate(*meta_tensors) + codegen = AutoChunkCodeGen( + meta_graph, + max_memory=max_memory, + ) + + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model.cuda().eval(), + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, + concrete_args={k: v for k, v in concrete_args}, + ) + graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) + gm.recompile() + + # init inputs + inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args] + inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] + model.cuda().eval() + + # bench + para_mem = float(parameter_size(model)) / 1024**2 + act_mem = _benchmark_memory(gm, inputs) + speed = _benchmark_speed(gm, inputs) + print("unet autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % + (speed, act_mem, para_mem, act_mem + para_mem)) + + +def _benchmark_autochunk_unet_origin( + model: Any, + data: tuple, +) -> None: + # build model and input + meta_args, concrete_args = data + if concrete_args is None: + concrete_args = {} + + # init inputs + inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args] + inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] + model.cuda().eval() + + # bench + para_mem = float(parameter_size(model)) / 1024**2 + act_mem = _benchmark_memory(model, inputs) + speed = _benchmark_speed(model, inputs) + print("unet origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % + (speed, act_mem, para_mem, act_mem + para_mem)) + return act_mem + + +def _benchmark_memory(model, inputs): + with torch.no_grad(): + torch.cuda.reset_peak_memory_stats() + now_mem = float(torch.cuda.memory_allocated()) / 1024**2 + model(*inputs) + new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2 + return new_max_mem - now_mem + + +def _benchmark_speed(model, inputs, loop=5): + with torch.no_grad(): + for _ in range(loop // 2 + 1): + model(*inputs) + torch.cuda.synchronize() + time1 = time.time() + for _ in range(loop): + model(*inputs) + torch.cuda.synchronize() + time2 = time.time() + return (time2 - time1) / loop + + +def benchmark_autochunk_unet(batch=1, height=448, width=448): + from test_autochunk_unet import UNet2DModel, get_data + model = UNet2DModel() + latent_shape = (batch, 3, height // 7, width // 7) + + print("\nbatch: %d, height: %d, width: %d" % (batch, height, width)) + max_mem = _benchmark_autochunk_unet_origin(model, get_data(latent_shape)) + for ratio in [0.5, 0.4, 0.3, 0.2]: + try: + _benchmark_autochunk_unet_gm(model, get_data(latent_shape), max_mem * ratio) + except RuntimeError as e: + if e.args[0] == 'Search failed. Try a larger memory threshold.': + break + except Exception as e: + raise e + _benchmark_autochunk_unet_gm(model, get_data(latent_shape), None) + + +if __name__ == "__main__": + # launch colossalai + colossalai.launch( + config={}, + rank=0, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + benchmark_autochunk_unet(batch=1, height=224 * 3, width=224 * 3) + benchmark_autochunk_unet(batch=1, height=224 * 4, width=224 * 4) + benchmark_autochunk_unet(batch=1, height=224 * 5, width=224 * 5) + benchmark_autochunk_unet(batch=1, height=224 * 6, width=224 * 6) diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e245f10d4576bed99bc322ba602dfb8e04c315b3 --- /dev/null +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py @@ -0,0 +1,136 @@ +from typing import Any, Dict, List + +import torch +import torch.fx + +import colossalai +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.core import global_context as gpc +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.testing import free_port + +if AUTOCHUNK_AVAILABLE: + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace + + +def assert_codegen_run( + model: Any, + meta_args: List, + concrete_args: List = None, + max_memory: int = None, + print_mem: bool = False, + print_est_mem: bool = False, + print_progress: bool = False, + print_code: bool = False, +) -> List[Dict]: + if concrete_args is None: + concrete_args = [] + model = model() + + # trace the meta graph and setup codegen + meta_graph = symbolic_trace( + model, + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, + concrete_args={k: v for k, v in concrete_args}, + ) + model = model.cuda().eval() + interp = MetaInfoProp(meta_graph) + meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args] + interp.propagate(*meta_tensors) + codegen = AutoChunkCodeGen( + meta_graph, + max_memory=max_memory, + print_mem=print_est_mem, + print_progress=print_progress, + ) + chunks = codegen.chunk_infos + + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model.cuda(), + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, + concrete_args={k: v for k, v in concrete_args}, + ) + graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) + gm.recompile() + + # assert chunk in code + code = graph.python_code("self").src + if print_code: + print(code) + assert "chunk_size = None; " in code + + # assert result + inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args] + inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] + model.cuda().eval() + gm.eval() + with torch.no_grad(): + if print_mem: + torch.cuda.reset_peak_memory_stats() + now_mem_gm = torch.cuda.memory_allocated() / 1024**2 + out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]) + if print_mem: + max_mem_gm = torch.cuda.max_memory_allocated() / 1024**2 + torch.cuda.reset_peak_memory_stats() + now_mem_ori = torch.cuda.memory_allocated() / 1024**2 + out_model = model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]) + if print_mem: + max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2 + print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm)) + + assert torch.allclose(out_gm["sample"], out_model["sample"], + atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(out_gm["sample"] - out_model["sample"])) + + return chunks + + +def run_test( + rank: int, + model: Any, + data: tuple, + max_memory: int, + print_code: bool = False, + print_mem: bool = False, + print_est_mem: bool = False, + print_progress: bool = False, + get_chunk_target: Any = None, +) -> None: + # launch colossalai + colossalai.launch( + config={}, + rank=rank, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + + # build model and input + meta_args, concrete_args = data + chunks = assert_codegen_run( + model, + meta_args=meta_args, + concrete_args=concrete_args, + max_memory=max_memory, + print_code=print_code, + print_mem=print_mem, + print_est_mem=print_est_mem, + print_progress=print_progress, + ) + + if get_chunk_target is not None: + chunk_found = [i["region"] for i in chunks] + chunk_target = get_chunk_target()[max_memory] + assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % ( + str(chunk_found), + str(chunk_target), + ) + + gpc.destroy() diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..ff0d4a1b53f58a22adc1bf820e25bdcf920728bb --- /dev/null +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py @@ -0,0 +1,63 @@ +from typing import List, Tuple + +import pytest +import torch + +try: + from diffusers import UNet2DModel + MODELS = [UNet2DModel] + HAS_REPO = True +except: + MODELS = [] + HAS_REPO = False + +from test_autochunk_diffuser_utils import run_test + +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn + +BATCH_SIZE = 1 +HEIGHT = 448 +WIDTH = 448 +IN_CHANNELS = 3 +LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7) + + +def get_data(shape: tuple) -> Tuple[List, List]: + sample = torch.randn(shape) + meta_args = [ + ("sample", sample), + ] + concrete_args = [("timestep", 50)] + return meta_args, concrete_args + + +@pytest.mark.skipif( + not (AUTOCHUNK_AVAILABLE and HAS_REPO), + reason="torch version is lower than 1.12.0", +) +@clear_cache_before_run() +@parameterize("model", MODELS) +@parameterize("shape", [LATENTS_SHAPE]) +@parameterize("max_memory", [None, 150, 300]) +def test_evoformer_block(model, shape, max_memory): + spawn( + run_test, + 1, + max_memory=max_memory, + model=model, + data=get_data(shape), + ) + + +if __name__ == "__main__": + run_test( + rank=0, + data=get_data(LATENTS_SHAPE), + max_memory=None, + model=UNet2DModel, + print_code=False, + print_mem=True, + print_est_mem=False, + print_progress=False, + ) diff --git a/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py b/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..63490aaee7ff0803b7c0aded3fca7d79084f9b27 --- /dev/null +++ b/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py @@ -0,0 +1,149 @@ +import time +from typing import Any, Dict, List + +import torch +import torch.fx + +import colossalai +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.profiler import parameter_size +from colossalai.utils import free_port + +if AUTOCHUNK_AVAILABLE: + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace + + +def _benchmark_autochunk_gpt_gm( + model: Any, + data: tuple, + max_memory: int = None, +) -> None: + model = model.eval().cpu() + + # build model and input + meta_args, concrete_args, sequence = data + if concrete_args is None: + concrete_args = {} + + # trace the meta graph and setup codegen + meta_graph = symbolic_trace( + model, + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()}, + concrete_args={k: v for k, v in concrete_args.items()}, + ) + interp = MetaInfoProp(meta_graph) + meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] + meta_tensors = [MetaTensor(i, fake_device="cpu") if isinstance(i, torch.Tensor) else i for i in meta_tensors] + interp.propagate(*meta_tensors) + codegen = AutoChunkCodeGen( + meta_graph, + max_memory=max_memory, + ) + + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model.cuda().eval(), + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()}, + concrete_args={k: v for k, v in concrete_args.items()}, + ) + graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) + gm.recompile() + + # init inputs + inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] + inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] + model.cuda() + + # bench + para_mem = float(parameter_size(model)) / 1024**2 * 6 + act_mem = _benchmark_memory(gm, inputs) + speed = _benchmark_speed(gm, inputs) + print("gpt autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % + (speed, act_mem, para_mem, act_mem + para_mem)) + + +def _benchmark_autochunk_gpt_origin( + model: Any, + data: tuple, +) -> None: + # build model and input + meta_args, concrete_args, sequence = data + if concrete_args is None: + concrete_args = {} + + # init inputs + inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] + inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] + model.cuda().eval() + + # bench + para_mem = float(parameter_size(model)) / 1024**2 * 6 + act_mem = _benchmark_memory(model, inputs) + speed = _benchmark_speed(model, inputs) + print("gpt origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % + (speed, act_mem, para_mem, act_mem + para_mem)) + return act_mem + + +def _benchmark_memory(model, inputs): + with torch.no_grad(): + torch.cuda.reset_peak_memory_stats() + now_mem = float(torch.cuda.memory_allocated()) / 1024**2 + model(*inputs) + new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2 + return new_max_mem - now_mem + + +def _benchmark_speed(model, inputs, loop=5): + with torch.no_grad(): + for _ in range(loop // 2 + 1): + model(*inputs) + torch.cuda.synchronize() + time1 = time.time() + for _ in range(loop): + model(*inputs) + torch.cuda.synchronize() + time2 = time.time() + return (time2 - time1) / loop + + +def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12): + from test_autochunk_gpt import GPT2Config, GPT2Model, get_data + model = GPT2Model + config = GPT2Config(n_embd=n_embd, n_positions=seq, n_layer=2, n_head=n_head) + model = model(config=config) + shape = [batch, seq] + print("\nbatch: %d, seq: %d, n_embd: %d, n_head: %d" % (batch, seq, n_embd, n_head)) + max_mem = _benchmark_autochunk_gpt_origin(model, get_data(shape)) + for ratio in [0.5, 0.4, 0.3, 0.2]: + try: + _benchmark_autochunk_gpt_gm(model, get_data(shape), max_mem * ratio) + except RuntimeError as e: + if e.args[0] == 'Search failed. Try a larger memory threshold.': + break + except Exception as e: + raise e + _benchmark_autochunk_gpt_gm(model, get_data(shape), None) + + +if __name__ == "__main__": + # launch colossalai + colossalai.launch( + config={}, + rank=0, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + benchmark_autochunk_gpt(batch=1, seq=1024, n_embd=768, n_head=12) + benchmark_autochunk_gpt(batch=1, seq=2048, n_embd=768, n_head=12) + benchmark_autochunk_gpt(batch=1, seq=4096, n_embd=768, n_head=12) + benchmark_autochunk_gpt(batch=1, seq=6144, n_embd=768, n_head=12) + benchmark_autochunk_gpt(batch=1, seq=8192, n_embd=768, n_head=12) diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..384706639e108c067839b878cca3bb5ba142b604 --- /dev/null +++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py @@ -0,0 +1,62 @@ +from typing import List, Tuple + +import pytest +import torch + +try: + from transformers import GPT2Config, GPT2Model + MODELS = [GPT2Model] + HAS_REPO = True +except: + MODELS = [] + HAS_REPO = False + +from test_autochunk_transformer_utils import run_test + +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn + +BATCH_SIZE = 1 +SEQ_LENGTH = 512 + + +def get_data(shape: tuple) -> Tuple[List, List]: + input_ids = torch.zeros(shape, dtype=torch.int64) + token_type_ids = torch.zeros(shape, dtype=torch.int64) + attention_mask = torch.ones(shape, dtype=torch.int64) + meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + concrete_args = {"past_key_values": None} + sequence = ["input_ids", "past_key_values", "attention_mask", "token_type_ids"] + return meta_args, concrete_args, sequence + + +@pytest.mark.skipif( + not (AUTOCHUNK_AVAILABLE and HAS_REPO), + reason="torch version is lower than 1.12.0", +) +@clear_cache_before_run() +@parameterize("model", MODELS) +@parameterize("shape", [(BATCH_SIZE, SEQ_LENGTH)]) +@parameterize("max_memory", [None, 6, 8]) +def test_autochunk_gpt(model, shape, max_memory): + spawn( + run_test, + 1, + data=get_data(shape), + max_memory=max_memory, + model=model, + config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4), + ) + + +if __name__ == "__main__": + run_test(rank=0, + data=get_data((BATCH_SIZE, SEQ_LENGTH)), + max_memory=None, + model=GPT2Model, + config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4), + print_code=False, + print_est_mem=False, + print_mem=False, + print_progress=False, + eval_mem=False) diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..faba138cd42cae467532587338dda5bf9a6b33fb --- /dev/null +++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py @@ -0,0 +1,141 @@ +from typing import Any, Dict, List + +import torch +import torch.fx + +import colossalai +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.core import global_context as gpc +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.testing import free_port + +if AUTOCHUNK_AVAILABLE: + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace + + +def assert_codegen_run( + model: Any, + data: tuple, + max_memory: int = None, + print_est_mem: bool = False, + print_mem: bool = False, + print_progress: bool = False, + print_code: bool = False, + eval_mem: bool = False, +) -> List[Dict]: + meta_args, concrete_args, sequence = data + if concrete_args is None: + concrete_args = {} + + # trace the meta graph and setup codegen + meta_graph = symbolic_trace( + model, + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()}, + concrete_args={k: v for k, v in concrete_args.items()}, + ) + interp = MetaInfoProp(meta_graph) + meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] + meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors] + interp.propagate(*meta_tensors) + codegen = AutoChunkCodeGen(meta_graph, + max_memory=max_memory, + print_mem=print_est_mem, + print_progress=print_progress, + eval_mem=eval_mem) + chunks = codegen.chunk_infos + + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model.cuda(), + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()}, + concrete_args={k: v for k, v in concrete_args.items()}, + ) + graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) + gm.recompile() + + # assert chunk in code + code = graph.python_code("self").src + if print_code: + print(code) + assert "chunk_size = None; " in code + + # assert result + inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] + inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] + model.cuda().eval() + gm.eval() + with torch.no_grad(): + if print_mem: + torch.cuda.reset_peak_memory_stats() + now_mem = torch.cuda.memory_allocated() / 1024**2 + out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]) + if print_mem: + new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + print("mem: %.2fMB" % (new_max_mem - now_mem)) + out_model = model(*inputs) + assert_allclose(out_model, out_gm) + return chunks + + +def assert_allclose(out_model: Any, out_gm: Any) -> None: + """ + assert allclose for out + """ + if isinstance(out_model, torch.Tensor): + assert torch.allclose(out_model, out_gm, + atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(out_model - out_gm)) + elif isinstance(out_model, dict): + for k in out_model.keys(): + assert_allclose(out_model[k], out_gm[k]) + elif isinstance(out_model, tuple) or isinstance(out_model, list) or isinstance(out_model, set): + for i, j in zip(out_model, out_gm): + assert_allclose(i, j) + + +def run_test( + rank: int, + model: Any, + config: Any, + data: tuple, + max_memory: int, + print_code: bool = False, + print_est_mem: bool = False, + print_mem: bool = False, + print_progress: bool = False, + eval_mem: bool = False, + get_chunk_target: Any = None, +) -> None: + model = model(config=config) + # launch colossalai + colossalai.launch( + config={}, + rank=rank, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + + # build model and input + chunks = assert_codegen_run(model, + data=data, + max_memory=max_memory, + print_code=print_code, + print_est_mem=print_est_mem, + print_mem=print_mem, + print_progress=print_progress, + eval_mem=eval_mem) + + if get_chunk_target is not None: + chunk_found = [i["region"] for i in chunks] + chunk_target = get_chunk_target()[max_memory] + assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % ( + str(chunk_found), + str(chunk_target), + ) diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..a98aa0e03954e49c9dfd361a1ed885e86a9c0771 --- /dev/null +++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py @@ -0,0 +1,53 @@ +from typing import List, Tuple + +import pytest +import torch + +try: + from timm.models.vision_transformer import vit_large_patch16_384 as vit + MODELS = [vit] + HAS_REPO = True +except: + MODELS = [] + HAS_REPO = False + +from test_autochunk_vit_utils import run_test + +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn + + +def get_data() -> Tuple[List, List]: + data = torch.rand(1, 3, 384, 384) + meta_args = {'x': data} + return data, meta_args + + +@pytest.mark.skipif( + not (AUTOCHUNK_AVAILABLE and HAS_REPO), + reason="torch version is lower than 1.12.0", +) +@clear_cache_before_run() +@parameterize("model", MODELS) +@parameterize("max_memory", [None, 32, 40]) +def test_evoformer_block(model, max_memory): + spawn( + run_test, + 1, + max_memory=max_memory, + model=model, + data=get_data(), + ) + + +if __name__ == "__main__": + run_test( + rank=0, + data=get_data(), + max_memory=None, + model=vit, + print_code=False, + print_mem=False, + print_est_mem=False, + print_progress=False, + ) diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..317606fc478184c5f659797042dec5702c9e488e --- /dev/null +++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py @@ -0,0 +1,128 @@ +from typing import Any, Dict, List + +import torch +import torch.fx + +import colossalai +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.core import global_context as gpc +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.testing import free_port + +if AUTOCHUNK_AVAILABLE: + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace + + +def assert_codegen_run( + model: Any, + meta_args: Dict, + data: Any, + max_memory: int = None, + print_mem: bool = False, + print_est_mem: bool = False, + print_progress: bool = False, + print_code: bool = False, +) -> List[Dict]: + model = model() + + # trace the meta graph and setup codegen + meta_graph = symbolic_trace(model, meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()}) + model = model.cuda().eval() + interp = MetaInfoProp(meta_graph) + meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args.items()] + interp.propagate(*meta_tensors) + codegen = AutoChunkCodeGen( + meta_graph, + max_memory=max_memory, + print_mem=print_est_mem, + print_progress=print_progress, + ) + chunks = codegen.chunk_infos + + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model.cuda(), + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()}, + ) + graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) + gm.recompile() + + # assert chunk in code + code = graph.python_code("self").src + if print_code: + print(code) + assert "chunk_size = None; " in code + + # assert result + inputs = [data.cuda()] + model.cuda().eval() + gm.eval() + with torch.no_grad(): + if print_mem: + torch.cuda.reset_peak_memory_stats() + now_mem_gm = torch.cuda.memory_allocated() / 1024**2 + out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]) + if print_mem: + max_mem_gm = torch.cuda.max_memory_allocated() / 1024**2 + torch.cuda.reset_peak_memory_stats() + now_mem_ori = torch.cuda.memory_allocated() / 1024**2 + out_model = model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]) + if print_mem: + max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2 + print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm)) + + assert torch.allclose(out_gm, out_model, + atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(out_gm - out_model)) + + return chunks + + +def run_test( + rank: int, + model: Any, + data: tuple, + max_memory: int, + print_code: bool = False, + print_mem: bool = False, + print_est_mem: bool = False, + print_progress: bool = False, + get_chunk_target: Any = None, +) -> None: + # launch colossalai + colossalai.launch( + config={}, + rank=rank, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + + # build model and input + data, meta_args = data + chunks = assert_codegen_run( + model, + meta_args=meta_args, + data=data, + max_memory=max_memory, + print_code=print_code, + print_mem=print_mem, + print_est_mem=print_est_mem, + print_progress=print_progress, + ) + + if get_chunk_target is not None: + chunk_found = [i["region"] for i in chunks] + chunk_target = get_chunk_target()[max_memory] + assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % ( + str(chunk_found), + str(chunk_target), + ) + + gpc.destroy() diff --git a/tests/test_booster/test_accelerator.py b/tests/test_booster/test_accelerator.py new file mode 100644 index 0000000000000000000000000000000000000000..895c494d0c17f9ad4822154376dcaf045a310ac8 --- /dev/null +++ b/tests/test_booster/test_accelerator.py @@ -0,0 +1,14 @@ +import torch.nn as nn + +from colossalai.booster.accelerator import Accelerator +from colossalai.testing import clear_cache_before_run, parameterize + + +@clear_cache_before_run() +@parameterize('device', ['cpu', 'cuda']) +def test_accelerator(device): + acceleartor = Accelerator(device) + model = nn.Linear(8, 8) + model = acceleartor.configure_model(model) + assert next(model.parameters()).device.type == device + del model, acceleartor diff --git a/tests/test_booster/test_mixed_precision/test_fp16_torch.py b/tests/test_booster/test_mixed_precision/test_fp16_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..963387da262bffea1ee555f8614ed2f60157b3ea --- /dev/null +++ b/tests/test_booster/test_mixed_precision/test_fp16_torch.py @@ -0,0 +1,40 @@ +import torch +from torch.optim import Adam + +import colossalai +from colossalai.booster.mixed_precision import FP16TorchMixedPrecision +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo + + +def run_torch_amp(rank, world_size, port): + # init dist env + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + sub_model_zoo = model_zoo.get_sub_registry('timm') + for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items(): + # dlrm_interactionarch has not parameters, so skip + if name == 'dlrm_interactionarch': + continue + + model = model_fn().cuda() + optimizer = Adam(model.parameters(), lr=1e-3) + criterion = lambda x: x.mean() + data = data_gen_fn() + data = { + k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + } + mixed_precision = FP16TorchMixedPrecision() + model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion) + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + optimizer.backward(loss) + optimizer.clip_grad_by_norm(1.0) + optimizer.step() + del model, optimizer, criterion, data, output, mixed_precision + + +@rerun_if_address_is_in_use() +def test_torch_ddp_plugin(): + spawn(run_torch_amp, 1) diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..985d7989fc9dfa7df2523899b7912e6e776d35dc --- /dev/null +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -0,0 +1,157 @@ +from contextlib import nullcontext +from typing import Optional + +import torch +import torch.distributed as dist + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin +from colossalai.fx import is_compatible_with_meta +from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.model.experimental import LazyInitContext +from colossalai.zero import ColoInitContext +from tests.kit.model_zoo import model_zoo + + +def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: + try: + if init_method == 'colo': + ctx = ColoInitContext() + elif init_method == 'lazy': + ctx = LazyInitContext() + else: + ctx = nullcontext() + plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5) + booster = Booster(plugin=plugin) + with ctx: + model = model_fn() + optimizer = HybridAdam(model.parameters(), lr=1e-3) + criterion = lambda x: x.mean() + data = data_gen_fn() + + data = { + k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + } + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + for n, p in model.named_parameters(): + assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter' + + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.step() + + except Exception as e: + return repr(e) + + +# TODO(ver217): CI does not support lazy now +# @parameterize('init_method', ['lazy', 'none', 'colo']) + + +@parameterize('init_method', ['none']) +def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): + """check gemini plugin over model zoo + + Args: + early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. + """ + is_support_meta = is_compatible_with_meta() + if not is_support_meta and init_method == 'lazy': + return + + passed_models = [] + failed_info = {} # (model_name, error) pair + + for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + # These models lead to CUDA error + if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp', + 'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'): + continue + # These models are not compatible with gemini + if name in [ + 'diffusers_clip_vision_model', 'timm_resnet', 'timm_beit', 'timm_beitv2', 'timm_eca_nfnet', + 'timm_efficientformer', 'timm_hrnet_w18_small', 'timm_nf_ecaresnet101', 'timm_nf_regnet_b0', + 'timm_skresnet18', 'timm_wide_resnet50_2', 'timm_convit', 'timm_dm_nfnet', 'timm_swin_transformer', + 'torchaudio_conformer', 'torchaudio_deepspeech', 'torchaudio_wavernn', 'torchaudio_tacotron', + 'deepfm_interactionarch', 'deepfm_simpledeepfmnn', 'dlrm', 'dlrm_interactionarch', + 'torchvision_googlenet', 'torchvision_inception_v3', 'torchvision_mobilenet_v3_small', + 'torchvision_resnet18', 'torchvision_resnext50_32x4d', 'torchvision_wide_resnet50_2', + 'torchvision_vit_b_16', 'torchvision_convnext_base', 'torchvision_swin_s', 'transformers_albert', + 'transformers_albert_for_pretraining', 'transformers_bert', 'transformers_bert_for_pretraining', + 'transformers_gpt_double_heads', 'torchaudio_hubert_base', 'torchaudio_wav2vec2_base', + 'transformers_t5_for_conditional_generation', 'transformers_t5', 'transformers_t5_encoder_model' + ]: + continue + + if init_method == 'lazy' and name in [ + 'timm_convmixer', 'timm_vision_transformer', 'timm_deit', 'timm_deit3', 'timm_inception_v3', + 'timm_tnt_b_patch16_224', 'timm_rexnet', 'torchvision_densenet121', 'torchvision_efficientnet_b0', + 'torchvision_mobilenet_v2', 'torchvision_mnasnet0_5', 'torchvision_regnet_x_16gf', + 'torchvision_shufflenet_v2_x0_5', 'torchvision_efficientnet_v2_s' + ]: + continue + + err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) + torch.cuda.empty_cache() + + if err is None: + passed_models.append(name) + else: + failed_info[name] = err + if early_stop: + break + + if dist.get_rank() == 0: + print(f'Init method: {init_method}') + print(f'Passed models({len(passed_models)}): {passed_models}\n\n') + print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') + assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) + + +def check_dataloader_sharding(): + plugin = GeminiPlugin() + + # create a custom dasetset with 0 to 10 + dataset = torch.utils.data.TensorDataset(torch.arange(0, 10)) + train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2) + + # get the first batch of data + batch = next(iter(train_dataloader))[0].cuda() + is_rank_0 = dist.get_rank() == 0 + + if is_rank_0: + batch_to_compare = batch.clone() + else: + batch_to_compare = batch + # pass to the rank 1 value to rank 0 + dist.broadcast(batch_to_compare, src=1) + + # compare on rank 0 + if is_rank_0: + assert not torch.equal(batch, + batch_to_compare), 'Same number was found across ranks but expected it to be different' + + +def run_dist(rank, world_size, port, early_stop: bool = True): + # init dist env + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + check_dataloader_sharding() + check_gemini_plugin(early_stop=early_stop) + + +@rerun_if_address_is_in_use() +def test_gemini_plugin(early_stop: bool = True): + spawn(run_dist, 4, early_stop=early_stop) + + +if __name__ == '__main__': + test_gemini_plugin(early_stop=False) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..e24196a149172c16300845f156d429338fa7a1b8 --- /dev/null +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -0,0 +1,122 @@ +from typing import Optional + +import torch +import torch.distributed as dist + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo + +# These models are not compatible with AMP +_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn`'] +# These models have no parameters +_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch'] +# These models will get stuck +_STUCK_MODELS = [ + 'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert', + 'transformers_bert_for_pretraining', 'transformers_gpt_double_heads' +] + + +def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: + try: + plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5) + booster = Booster(plugin=plugin) + model = model_fn() + optimizer = HybridAdam(model.parameters(), lr=1e-3) + criterion = lambda x: x.mean() + data = data_gen_fn() + + data = { + k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + } + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.step() + + except Exception as e: + return repr(e) + + +@parameterize('stage', [2]) +def check_low_level_zero_plugin(stage: int, early_stop: bool = True): + """check low level zero plugin over model zoo + + Args: + stage (int), stage of low level zero plugin + early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. + """ + passed_models = [] + failed_info = {} # (model_name, error) pair + ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS + skipped_models = [] + + for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + # FIXME(ver217): fix these models + if name in ignore_models: + skipped_models.append(name) + continue + err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) + torch.cuda.empty_cache() + + if err is None: + passed_models.append(name) + else: + failed_info[name] = err + if early_stop: + break + + if dist.get_rank() == 0: + print(f'Passed models({len(passed_models)}): {passed_models}\n\n') + print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') + print(f'Skipped models({len(skipped_models)}): {skipped_models}\n\n') + assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) + + +def check_dataloader_sharding(): + plugin = LowLevelZeroPlugin() + + # create a custom dasetset with 0 to 10 + dataset = torch.utils.data.TensorDataset(torch.arange(0, 10)) + train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2) + + # get the first batch of data + batch = next(iter(train_dataloader))[0].cuda() + is_rank_0 = dist.get_rank() == 0 + + if is_rank_0: + batch_to_compare = batch.clone() + else: + batch_to_compare = batch + # pass to the rank 1 value to rank 0 + dist.broadcast(batch_to_compare, src=1) + + # compare on rank 0 + if is_rank_0: + assert not torch.equal(batch, + batch_to_compare), 'Same number was found across ranks but expected it to be different' + + +def run_dist(rank, world_size, port, early_stop: bool = True): + # init dist env + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + check_low_level_zero_plugin(early_stop=early_stop) + + +@rerun_if_address_is_in_use() +def test_low_level_zero_plugin(early_stop: bool = True): + spawn(run_dist, 2, early_stop=early_stop) + + +if __name__ == '__main__': + test_low_level_zero_plugin(early_stop=False) diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..5354eae01d40fbf4a0cd740c6568fb296b08f6ef --- /dev/null +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -0,0 +1,103 @@ +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import SGD + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.interface import OptimizerWrapper +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo + + +def run_fn(model_fn, data_gen_fn, output_transform_fn): + plugin = TorchDDPPlugin() + booster = Booster(plugin=plugin) + model = model_fn() + optimizer = SGD(model.parameters(), lr=1e-3) + criterion = lambda x: x.mean() + data = data_gen_fn() + + data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + assert isinstance(model.module, DDP) + assert isinstance(optimizer, OptimizerWrapper) + + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.clip_grad_by_norm(1.0) + optimizer.step() + + +def check_torch_ddp_plugin(): + for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + if name == 'dlrm_interactionarch': + continue + run_fn(model_fn, data_gen_fn, output_transform_fn) + torch.cuda.empty_cache() + + +def check_dataloader_sharding(): + plugin = TorchDDPPlugin() + + # create a custom dasetset with 0 to 10 + dataset = torch.utils.data.TensorDataset(torch.arange(0, 10)) + train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2) + + # get the first batch of data + batch = next(iter(train_dataloader))[0].cuda() + is_rank_0 = dist.get_rank() == 0 + + if is_rank_0: + batch_to_compare = batch.clone() + else: + batch_to_compare = batch + # pass to the rank 1 value to rank 0 + dist.broadcast(batch_to_compare, src=1) + + # compare on rank 0 + if is_rank_0: + assert not torch.equal(batch, + batch_to_compare), 'Same number was found across ranks but expected it to be different' + + +def check_checkpoint_save_and_load(): + model_fn, data_gen_fn, output_transform_fn, _ = model_zoo['timm_resnet'] + + plugin = TorchDDPPlugin() + booster = Booster(plugin=plugin) + + model = model_fn() + optimizer = SGD(model.parameters(), lr=1e-3) + criterion = lambda x: x.mean() + data = data_gen_fn() + + data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + + +def run_dist(rank, world_size, port): + # init dist env + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + check_dataloader_sharding() + check_torch_ddp_plugin() + + +@rerun_if_address_is_in_use() +def test_torch_ddp_plugin(): + spawn(run_dist, 2) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5ce10054f7c85d86a9ab82a8141c0edfb7389c --- /dev/null +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -0,0 +1,126 @@ +import tempfile +import pytest +import torch +import logging +from torch.optim import Adam +from torchvision.models import resnet18 +from pathlib import Path +import os +import subprocess + +from colossalai.checkpoint_io import GeneralCheckpointIO +from colossalai.testing import clear_cache_before_run, parameterize + +# ======== +# Note: +# 1. due to checkpoint IO can be quite slow if tested with all models, we will only test on resnet for now +# 2. we will test on both sharded and unsharded checkpoints +# 3. implement sharded checkpoint and test it +# ======== + + +@clear_cache_before_run() +@parameterize('use_safetensors', [True, False]) +def test_unsharded_checkpoint(use_safetensors: bool): + # create a model and optimizer + model = resnet18() + optimizer = Adam(model.parameters(), lr=0.001) + + # create test data sample + x = torch.randn(1, 3, 224, 224) + + # run fwd and bwd + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # create a temp file for checkpoint + if use_safetensors: + suffix = ".safetensors" + else: + suffix = ".bin" + model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix) + optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() + + # save the model and optimizer + ckpt_io = GeneralCheckpointIO() + ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors) + ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name) + + # create new model + new_model = resnet18() + new_optimizer = Adam(new_model.parameters(), lr=0.001) + + # load the model and optimizer + ckpt_io.load_model(new_model, model_ckpt_tempfile.name) + ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) + + + # check for model and optimizer state dict recursively + recursive_check(model.state_dict(), new_model.state_dict()) + recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) + +@pytest.mark.parametrize('use_safetensors', [True, False]) +def test_sharded_checkpoint(use_safetensors: bool): + # create a model and optimizer + model = resnet18() + optimizer = Adam(model.parameters(), lr=0.001) + # create test data sample + x = torch.randn(1, 3, 224, 224) + + # run fwd and bwd + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # create a temp file for checkpoint + if use_safetensors: + suffix = ".safetensors" + SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" + else: + suffix = ".bin" + WEIGHTS_INDEX_NAME = "model.bin.index.json" + + # model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix) + model_ckpt_dir = tempfile.TemporaryDirectory() + optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() + + # save the model and optimizer + ckpt_io = GeneralCheckpointIO() + + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors) + ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False) + + # create new model + new_model = resnet18() + new_optimizer = Adam(new_model.parameters(), lr=0.001) + + ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True) + ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) + + # check for model and optimizer state dict recursively + recursive_check(model.state_dict(), new_model.state_dict()) + recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) + + +# do recursive check for the optimizer state dict +# if the value is a dict, compare its values +# if the value is a list, comapre all elements one-by-one +# if the value is a torch.Tensor, use torch.equal +# otherwise use assertEqual +def recursive_check(d1, d2): + for k, v in d1.items(): + if isinstance(v, dict): + recursive_check(v, d2[k]) + elif isinstance(v, list): + for i in range(len(v)): + if isinstance(v[i], torch.Tensor): + assert torch.equal(v[i], d2[k][i]) + else: + assert v[i] == d2[k][i] + elif isinstance(v, torch.Tensor): + assert torch.equal(v, d2[k]) + else: + assert v == d2[k] diff --git a/tests/test_cluster/test_device_mesh_manager.py b/tests/test_cluster/test_device_mesh_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..b42ef1fe0062dc2dec944c321a63bf9812d86a6c --- /dev/null +++ b/tests/test_cluster/test_device_mesh_manager.py @@ -0,0 +1,33 @@ +import torch + +from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import spawn + + +def check_device_mesh_manager(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + device_mesh_manager = DeviceMeshManager() + device_mesh_info_auto = DeviceMeshInfo(physical_ids=[0, 1, 2, 3],) + device_mesh_auto = device_mesh_manager.create_device_mesh('0', device_mesh_info_auto) + assert device_mesh_auto.shape == (2, 2) + assert device_mesh_auto._logical_mesh_id.tolist() == [[0, 1], [2, 3]] + + device_mesh_info_with_shape = DeviceMeshInfo( + physical_ids=[0, 1, 2, 3], + mesh_shape=(2, 2), + ) + device_mesh_with_shape = device_mesh_manager.create_device_mesh('1', device_mesh_info_with_shape) + + assert device_mesh_with_shape.shape == (2, 2) + assert device_mesh_with_shape._logical_mesh_id.tolist() == [[0, 1], [2, 3]] + + +def test_device_mesh_manager(): + spawn(check_device_mesh_manager, 4) + + +if __name__ == '__main__': + test_device_mesh_manager() diff --git a/tests/test_comm/test_boardcast_send_recv_v2.py b/tests/test_comm/test_boardcast_send_recv_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..253f6f21cd804c7d4d1e3d3e62b5ce2253ef271f --- /dev/null +++ b/tests/test_comm/test_boardcast_send_recv_v2.py @@ -0,0 +1,47 @@ +import pytest +import torch + +from colossalai.communication.p2p_v2 import _recv_object, _send_object +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn + +disable_existing_loggers() +world_size = 4 +CONFIG = dict(parallel=dict(pipeline=world_size)) +torch.manual_seed(123) + + +def check_layer(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl', verbose=False) + rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + if rank == 0: + obj = [torch.randn(3,)] + _send_object(obj, 1) + + if rank == 1: + _recv_object(0) + + if rank == 2: + _recv_object(3) + + if rank == 3: + obj = [torch.randn(3,)] + _send_object(obj, 2) + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_object_list_p2p(): + spawn(check_layer, world_size) + + +if __name__ == '__main__': + test_object_list_p2p() diff --git a/tests/test_comm/test_comm.py b/tests/test_comm/test_comm.py new file mode 100644 index 0000000000000000000000000000000000000000..747596bd2dedff9a20222669bb611cf96d020724 --- /dev/null +++ b/tests/test_comm/test_comm.py @@ -0,0 +1,71 @@ +import pytest +import torch +import torch.distributed as dist + +from colossalai.communication import all_gather, all_reduce, reduce_scatter +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + +CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) + +SIZE = 8 + + +def check_all_gather(): + tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) + tensor = tensor.to(get_current_device()) + print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True) + print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + op.wait() + print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + torch.cuda.synchronize() + + +def check_reduce_scatter(): + tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) + tensor = tensor.to(get_current_device()) + print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True) + print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + op.wait() + print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + torch.cuda.synchronize() + + +def check_all_reduce(): + tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) + tensor = tensor.to(get_current_device()) + print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True) + print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + op.wait() + print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + torch.cuda.synchronize() + + +def check_layer(rank, world_size, port): + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + assert dist.get_rank() == gpc.get_global_rank() + print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size())) + + check_all_gather() + check_reduce_scatter() + check_all_reduce() + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_comm(): + spawn(check_layer, 4) + + +if __name__ == '__main__': + test_comm() diff --git a/tests/test_comm/test_object_list_p2p.py b/tests/test_comm/test_object_list_p2p.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d7630c154307bb680a4fd85babc012bbaade47 --- /dev/null +++ b/tests/test_comm/test_object_list_p2p.py @@ -0,0 +1,106 @@ +import pytest +import torch + +from colossalai.communication.p2p import ( + recv_backward, + recv_forward, + send_backward, + send_backward_recv_forward, + send_forward, + send_forward_recv_backward, +) +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict(parallel=dict(pipeline=2)) +torch.manual_seed(123) +LIST_LENGTH = 3 +TENSOR_SIZE = torch.Size((3, 3)) +TENSOR_SIZE_LIST = [TENSOR_SIZE for i in range(LIST_LENGTH)] +data = torch.rand(3, 3) +data_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)] +grad = torch.rand(3, 3) +grad_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)] + + +def check_send_recv_forward(): + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + device = torch.device('cuda:0') + data_to_send = data.to(device) + data_list_to_send = [] + for data_in_list in data_list: + data_list_to_send.append(data_in_list.to(device)) + send_forward(data_to_send) + send_forward(data_list_to_send) + else: + device = torch.device('cuda:1') + data_recv = recv_forward(TENSOR_SIZE) + data_list_recv = recv_forward(TENSOR_SIZE_LIST) + data_to_check = data.to(device) + assert data_recv.equal(data_to_check) + for data_recv, data_send in zip(data_list_recv, data_list): + data_to_check = data_send.to(device) + assert data_recv.equal(data_to_check) + + +def check_send_recv_backward(): + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + device = torch.device('cuda:0') + grad_recv = recv_backward(TENSOR_SIZE) + grad_list_recv = recv_backward(TENSOR_SIZE_LIST) + grad_to_check = grad.to(device) + assert grad_recv.equal(grad_to_check) + for grad_recv, grad_send in zip(grad_list_recv, grad_list): + grad_to_check = grad_send.to(device) + assert grad_recv.equal(grad_to_check) + else: + device = torch.device('cuda:1') + grad_to_send = grad.to(device) + grad_list_to_send = [] + for grad_in_list in grad_list: + grad_list_to_send.append(grad_in_list.to(device)) + send_backward(grad_to_send) + send_backward(grad_list_to_send) + + +def check_send_recv_forward_backward(): + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + device = torch.device('cuda:0') + data_list_to_send = [] + for data_in_list in data_list: + data_list_to_send.append(data_in_list.to(device)) + grad_list_recv = send_forward_recv_backward(data_list_to_send, TENSOR_SIZE_LIST) + + for grad_recv, grad_send in zip(grad_list_recv, grad_list): + grad_to_check = grad_send.to(device) + assert grad_recv.equal(grad_to_check) + else: + device = torch.device('cuda:1') + grad_list_to_send = [] + for grad_in_list in grad_list: + grad_list_to_send.append(grad_in_list.to(device)) + data_list_recv = send_backward_recv_forward(grad_list_to_send, TENSOR_SIZE_LIST) + for data_recv, data_send in zip(data_list_recv, data_list): + data_to_check = data_send.to(device) + assert data_recv.equal(data_to_check) + + +def check_layer(rank, world_size, port): + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_send_recv_forward() + check_send_recv_backward() + check_send_recv_forward_backward() + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_object_list_p2p(): + spawn(check_layer, 2) + + +if __name__ == '__main__': + test_object_list_p2p() diff --git a/tests/test_comm/test_object_list_p2p_v2.py b/tests/test_comm/test_object_list_p2p_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..cae38385b6e17ddd8ebbd7e8ad7fbb9d3041ef24 --- /dev/null +++ b/tests/test_comm/test_object_list_p2p_v2.py @@ -0,0 +1,125 @@ +import pytest +import torch + +from colossalai.communication.p2p_v2 import recv_backward, recv_forward, send_backward, send_forward +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn + +disable_existing_loggers() + +# config +world_size = 4 +CONFIG = dict(parallel=dict(pipeline=4)) +torch.manual_seed(123) +use_scatter_gather_tensors = False + +# data +torch.manual_seed(123) +LIST_LENGTH = 3 +TENSOR_SIZE = torch.Size((3, 3)) +TENSOR_SIZE_LIST = [TENSOR_SIZE for i in range(LIST_LENGTH)] +data = torch.rand(3, 3) +data_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)] +grad = torch.rand(3, 3) +grad_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)] + + +def check_send_recv_forward(): + disable_existing_loggers() + local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + if local_rank == 0: + device = torch.device('cuda:0') + data_to_send = data.to(device) + data_list_to_send = [] + for data_in_list in data_list: + data_list_to_send.append(data_in_list.to(device)) + + send_forward(data_to_send, scatter_gather_tensors=use_scatter_gather_tensors) + send_forward(data_list_to_send, scatter_gather_tensors=use_scatter_gather_tensors) + + elif local_rank == 1: + device = torch.device('cuda:1') + + data_recv = recv_forward(TENSOR_SIZE, scatter_gather_tensors=use_scatter_gather_tensors) + data_list_recv = recv_forward(TENSOR_SIZE_LIST, scatter_gather_tensors=use_scatter_gather_tensors) + + data_to_check = data.to(device) + + assert data_recv.equal(data_to_check) + + for data_recv, data_send in zip(data_list_recv, data_list): + data_to_check = data_send.to(device) + data_recv = data_recv.to(device) + assert data_recv.equal(data_to_check) + + +def check_send_recv_backward(): + disable_existing_loggers() + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + device = torch.device('cuda:0') + grad_recv = recv_backward(TENSOR_SIZE) + grad_list_recv = recv_backward(TENSOR_SIZE_LIST) + + grad_to_check = grad.to(device) + grad_recv = grad_recv[0].to(device) + + assert grad_recv.equal(grad_to_check) + for grad_recv, grad_send in zip(grad_list_recv, grad_list): + grad_recv = grad_recv.to(device) + grad_to_check = grad_send.to(device) + assert grad_recv.equal(grad_to_check) + else: + device = torch.device('cuda:1') + grad_to_send = grad.to(device) + grad_list_to_send = [] + for grad_in_list in grad_list: + grad_list_to_send.append(grad_in_list.to(device)) + send_backward(grad_to_send) + send_backward(grad_list_to_send) + + +def check_small_pipeline(): + disable_existing_loggers() + # make sure the rank is 4 + assert gpc.world_size == 4, "make sure to set world size to 4 to start the training process" + local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + if local_rank == 0: + obj = [1, torch.randn(2, 2).cuda(), None] + send_forward(obj) + elif local_rank == 1: + obj = recv_forward() + send_forward(obj) + elif local_rank == 2: + obj = recv_forward() + send_forward(obj) + elif local_rank == 3: + obj = recv_forward() + else: + pass + + +def check_layer(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + disable_existing_loggers() + # check_send_recv_forward() + check_small_pipeline() + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_object_list_p2p(): + spawn(check_layer, world_size) + + +if __name__ == '__main__': + disable_existing_loggers() + test_object_list_p2p() diff --git a/tests/test_config/sample_config.py b/tests/test_config/sample_config.py new file mode 100644 index 0000000000000000000000000000000000000000..08ca108281b9c0700fef4ecb2c14416ccbabfd9f --- /dev/null +++ b/tests/test_config/sample_config.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +train_data = dict( + dataset=dict( + type='CIFAR10Dataset', + root='/path/to/data', + download=True, + transform_pipeline=[ + dict(type='RandomResizedCrop', size=224), + dict(type='RandomHorizontalFlip'), + dict(type='ToTensor'), + dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ] + ), + dataloader=dict( + batch_size=64, + pin_memory=True, + num_workers=4, + sampler=dict( + type='DataParallelSampler', + shuffle=True, + ) + ) +) diff --git a/tests/test_config/test_load_config.py b/tests/test_config/test_load_config.py new file mode 100644 index 0000000000000000000000000000000000000000..550af2a4ae81656b4da0c159ce5a04bfdbb891cc --- /dev/null +++ b/tests/test_config/test_load_config.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from pathlib import Path + +import pytest + +from colossalai.context.config import Config + + +@pytest.mark.cpu +def test_load_config(): + filename = Path(__file__).parent.joinpath('sample_config.py') + config = Config.from_file(filename) + + assert config.train_data, 'cannot access train data as attribute' + assert config.train_data.dataset, 'cannot access grandchild attribute' + assert isinstance(config.train_data.dataset.transform_pipeline[0], dict), \ + f'expected attribute transform_pipeline elements to be a dict, but found {type(config.train_data.dataset.transform_pipeline)}' diff --git a/tests/test_context/configs/parallel_2d_init.py b/tests/test_context/configs/parallel_2d_init.py new file mode 100644 index 0000000000000000000000000000000000000000..6af884450ad0fee42d86fd1ad7ee950d576dd7da --- /dev/null +++ b/tests/test_context/configs/parallel_2d_init.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +parallel = dict( + pipeline=dict(size=2), + tensor=dict( + size=4, + mode='2d' + ) +) diff --git a/tests/test_context/configs/parallel_2p5d_init.py b/tests/test_context/configs/parallel_2p5d_init.py new file mode 100644 index 0000000000000000000000000000000000000000..c2d896d383e26d1530bd05d4127dfdafec57d826 --- /dev/null +++ b/tests/test_context/configs/parallel_2p5d_init.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +parallel = dict( + pipeline=dict(size=2), + tensor=dict( + size=8, + depth=2, + mode='2.5d' + ) +) diff --git a/tests/test_context/configs/parallel_3d_init.py b/tests/test_context/configs/parallel_3d_init.py new file mode 100644 index 0000000000000000000000000000000000000000..0ec724f8bb4f2513457568eaeb221727e4da2ff1 --- /dev/null +++ b/tests/test_context/configs/parallel_3d_init.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +parallel = dict( + pipeline=dict(size=2), + tensor=dict( + size=8, + mode='3d' + ) +) diff --git a/tests/test_context/test_hybrid_parallel.py b/tests/test_context/test_hybrid_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..9f26a5af53ce6d13fc0d00e4c4d8b949b1172b17 --- /dev/null +++ b/tests/test_context/test_hybrid_parallel.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from pathlib import Path + +import pytest +import torch + +from colossalai import launch +from colossalai.context import reset_seeds +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as tp_env +from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn + +CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py')) + + +def check_data_parallel_rank(rank): + global_world_size = gpc.get_world_size(ParallelMode.GLOBAL) + mp_size = gpc.get_world_size(ParallelMode.MODEL) + num_dp_groups = global_world_size // mp_size + dp_local_rank = gpc.get_local_rank(ParallelMode.DATA) + + assert gpc.get_world_size(ParallelMode.DATA) == num_dp_groups + + for group_idx in range(num_dp_groups): + ranks_in_dp_group = range(group_idx * mp_size, (group_idx + 1) * mp_size) + if rank in ranks_in_dp_group: + assert dp_local_rank == group_idx + + +def check_pipeline_parallel_rank(rank): + mp_world_size = gpc.get_world_size(ParallelMode.MODEL) + tp_world_size = gpc.get_world_size(ParallelMode.TENSOR) + num_pipeline_stage = mp_world_size // tp_world_size + pipeline_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + for stage_idx in range(num_pipeline_stage): + ranks_in_current_stage = range(stage_idx * tp_world_size, (stage_idx + 1) * tp_world_size) + if rank in ranks_in_current_stage: + assert stage_idx == pipeline_local_rank + + +def check_model_parallel_rank(rank): + mp_size = gpc.get_world_size(ParallelMode.MODEL) + rank_within_mp_group = rank % mp_size + mp_local_rank = gpc.get_local_rank(ParallelMode.MODEL) + assert rank_within_mp_group == mp_local_rank + + +def check_tensor_parallel_rank(rank): + if tp_env.mode == '2d': + check_2d_tensor_parallel_rank(rank) + elif tp_env == '2.5d': + check_2p5d_tensor_parallel_rank(rank) + elif tp_env == '3d': + check_3d_tensor_parallel_rank(rank) + + +def get_tp_info(): + global_world_size = gpc.get_world_size(ParallelMode.GLOBAL) + tp_world_size = gpc.get_world_size(ParallelMode.TENSOR) + num_tp_groups = global_world_size // tp_world_size + tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR) + return tp_local_rank, tp_world_size, num_tp_groups + + +def check_2d_tensor_parallel_rank(rank): + tp_local_rank, tp_world_size, num_tp_groups = get_tp_info() + + for group_id in range(num_tp_groups): + ranks_in_current_tp_group = range(group_id * tp_world_size, (group_id + 1) * tp_world_size) + + if rank in ranks_in_current_tp_group: + col_local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + row_local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + + assert col_local_rank == tp_local_rank // tp_env.summa_dim + assert row_local_rank == tp_local_rank % tp_env.summa_dim + + +def check_2p5d_tensor_parallel_rank(rank): + tp_local_rank, tp_world_size, num_tp_groups = get_tp_info() + + for group_id in range(num_tp_groups): + ranks_in_current_tp_group = range(group_id * tp_world_size, (group_id + 1) * tp_world_size) + + if rank in ranks_in_current_tp_group: + rp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + cp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + dp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + xp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_XZ) + + assert rp_rank == tp_local_rank % tp_env.summa_dim + assert cp_rank == tp_local_rank // tp_env.tesseract_dim + assert dp_rank == tp_local_rank // (tp_env.summa_dim**2) + assert xp_rank == tp_local_rank // tp_env.summa_dim + + +def check_3d_tensor_parallel_rank(rank): + tp_local_rank, tp_world_size, num_tp_groups = get_tp_info() + + for group_id in range(num_tp_groups): + ranks_in_current_tp_group = range(group_id * tp_world_size, (group_id + 1) * tp_world_size) + + if rank in ranks_in_current_tp_group: + ip_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT) + wp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT) + op_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT) + + assert ip_rank == tp_local_rank % tp_env.depth_3d + assert wp_rank == tp_local_rank // tp_env.depth_3d + assert op_rank == tp_local_rank // (tp_env.depth_3d**2) + + +def init_context(config_path, rank, world_size, backend, port, host): + dist_args = dict(config=config_path, + rank=rank, + world_size=world_size, + backend=backend, + port=port, + host=host, + verbose=True) + launch(**dist_args) + + check_tensor_parallel_rank(rank) + check_data_parallel_rank(rank) + check_pipeline_parallel_rank(rank) + check_model_parallel_rank(rank) + gpc.destroy() + torch.cuda.empty_cache() + + +def run_dist(rank, world_size, port, backend, port_list, host): + for config_path, current_port in zip(CONFIG_PATH_LIST, port_list): + init_context(config_path=config_path, + rank=rank, + world_size=world_size, + backend=backend, + port=current_port, + host=host) + reset_seeds() + + +@pytest.mark.cpu +@rerun_if_address_is_in_use() +def test_context(): + """ + As no computation or communication is done, we can run this test on CPU. + """ + world_size = 32 + port_list = [] + + for _ in range(len(CONFIG_PATH_LIST)): + while True: + port = free_port() + if port not in port_list: + port_list.append(port) + break + + spawn(run_dist, world_size, backend='gloo', port_list=port_list, host='localhost') + + +if __name__ == '__main__': + test_context() diff --git a/tests/test_data/test_cifar10_dataset.py b/tests/test_data/test_cifar10_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4b9ca61d9f1796353f48fd96ca2d8b1cbacfa17d --- /dev/null +++ b/tests/test_data/test_cifar10_dataset.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import os +from pathlib import Path + +import pytest +from torchvision import transforms, datasets +from torch.utils.data import DataLoader + + +@pytest.mark.cpu +def test_cifar10_dataset(): + # build transform + transform_pipeline = [transforms.ToTensor()] + transform_pipeline = transforms.Compose(transform_pipeline) + + # build dataset + dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) + + # build dataloader + dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=2) + data_iter = iter(dataloader) + img, label = data_iter.next() + + +if __name__ == '__main__': + test_cifar10_dataset() diff --git a/tests/test_data/test_data_parallel_sampler.py b/tests/test_data/test_data_parallel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..2ad3fd696c39179d9eaec2a331925ea4f5ab3bf1 --- /dev/null +++ b/tests/test_data/test_data_parallel_sampler.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import os +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +from torchvision import datasets, transforms + +import colossalai +from colossalai.context import Config, ParallelMode +from colossalai.core import global_context as gpc +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_dataloader + +CONFIG = Config(dict( + parallel=dict( + pipeline=dict(size=1), + tensor=dict(size=1, mode=None), + ), + seed=1024, +)) + + +def run_data_sampler(rank, world_size, port): + dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost') + colossalai.launch(**dist_args) + print('finished initialization') + + # build dataset + transform_pipeline = [transforms.ToTensor()] + transform_pipeline = transforms.Compose(transform_pipeline) + dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) + + # build dataloader + dataloader = get_dataloader(dataset, batch_size=8, add_sampler=True) + + data_iter = iter(dataloader) + img, label = data_iter.next() + img = img[0] + + if gpc.get_local_rank(ParallelMode.DATA) != 0: + img_to_compare = img.clone() + else: + img_to_compare = img + dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA)) + + if gpc.get_local_rank(ParallelMode.DATA) != 0: + assert not torch.equal( + img, img_to_compare), 'Same image was distributed across ranks but expected it to be different' + torch.cuda.empty_cache() + + +@pytest.mark.cpu +@rerun_if_address_is_in_use() +def test_data_sampler(): + spawn(run_data_sampler, 4) + + +if __name__ == '__main__': + test_data_sampler() diff --git a/tests/test_data/test_deterministic_dataloader.py b/tests/test_data/test_deterministic_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..239e79dff7d85f81f71802cbbdedbae6789b869d --- /dev/null +++ b/tests/test_data/test_deterministic_dataloader.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import os +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +from torchvision import datasets, transforms + +import colossalai +from colossalai.context import Config, ParallelMode +from colossalai.core import global_context as gpc +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_dataloader + +CONFIG = Config( + dict( + train_data=dict( + dataset=dict( + type='CIFAR10', + root=Path(os.environ['DATA']), + train=True, + download=True, + ), + dataloader=dict(num_workers=2, batch_size=2, shuffle=True), + ), + parallel=dict( + pipeline=dict(size=1), + tensor=dict(size=1, mode=None), + ), + seed=1024, + )) + + +def run_data_sampler(rank, world_size, port): + dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost') + colossalai.launch(**dist_args) + + # build dataset + transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)] + transform_pipeline = transforms.Compose(transform_pipeline) + dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) + + # build dataloader + dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False) + + data_iter = iter(dataloader) + img, label = data_iter.next() + img = img[0] + + if gpc.get_local_rank(ParallelMode.DATA) != 0: + img_to_compare = img.clone() + else: + img_to_compare = img + dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA)) + + if gpc.get_local_rank(ParallelMode.DATA) != 0: + # this is without sampler + # this should be false if data parallel sampler to given to the dataloader + assert torch.equal(img, + img_to_compare), 'Same image was distributed across ranks and expected it to be the same' + torch.cuda.empty_cache() + + +@pytest.mark.cpu +@rerun_if_address_is_in_use() +def test_data_sampler(): + spawn(run_data_sampler, 4) + + +if __name__ == '__main__': + test_data_sampler() diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..4d63592f12b0321ee3f48c0f289871cb213d6a17 --- /dev/null +++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py @@ -0,0 +1,100 @@ +import os +from pathlib import Path + +import pytest +import torch +from torchvision import transforms +from torchvision.datasets import CIFAR10 + +import colossalai +from colossalai.amp import AMP_TYPE +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.pipeline.pipelinable import PipelinableContext +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.trainer import Trainer, hooks +from colossalai.utils import get_dataloader + +BATCH_SIZE = 4 +NUM_EPOCHS = 60 +WARMUP_EPOCHS = 5 +CONFIG = dict(NUM_MICRO_BATCHES=2, + parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')), + fp16=dict(mode=AMP_TYPE.NAIVE), + gradient_accumulation=2) + + +def run_trainer(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + logger = get_dist_logger() + + # get logger + logger = get_dist_logger() + + pipelinable = PipelinableContext() + try: + from titans.model.vit import vit_tiny_patch4_32 + except ImportError: + logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed') + logger.warning('please install titan from https://github.com/hpcaitech/Titans') + return + with pipelinable: + model = vit_tiny_patch4_32() + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + + # craete dataloaders + root = Path(os.environ['DATA']) + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4, pad_if_needed=True), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train) + train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True) + + # create loss function + criterion = CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS) + + # intiailize + engine, train_dataloader, *_ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + + logger = get_dist_logger() + + trainer = Trainer(engine=engine, logger=logger) + + hook_list = [ + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), + ] + + trainer.fit(train_dataloader=train_dataloader, + epochs=NUM_EPOCHS, + max_steps=2, + hooks=hook_list, + display_progress=True) + + +@pytest.mark.dist +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_hybrid_parallel(): + spawn(run_trainer, 8) + + +if __name__ == '__main__': + test_hybrid_parallel() diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..67d2ba5f5d987606586d5af0d17e058f9fb7c8b5 --- /dev/null +++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py @@ -0,0 +1,104 @@ +import os +from pathlib import Path + +import pytest +import torch +from torchvision import transforms +from torchvision.datasets import CIFAR10 + +import colossalai +from colossalai.amp import AMP_TYPE +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2 +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.pipeline.pipelinable import PipelinableContext +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.trainer import Trainer, hooks +from colossalai.utils import get_dataloader + +disable_existing_loggers() +BATCH_SIZE = 4 +NUM_EPOCHS = 10 +WARMUP_EPOCHS = 5 +CONFIG = dict(NUM_MICRO_BATCHES=2, + parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')), + fp16=dict(mode=AMP_TYPE.NAIVE), + gradient_accumulation=2) + + +def run_trainer(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + disable_existing_loggers() + # get logger + logger = get_dist_logger() + + pipelinable = PipelinableContext() + try: + from titans.model.vit import vit_tiny_patch4_32 + except ImportError: + logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed') + logger.warning('please install titan from https://github.com/hpcaitech/Titans') + return + with pipelinable: + model = vit_tiny_patch4_32() + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + + # craete dataloaders + root = Path(os.environ['DATA']) + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4, pad_if_needed=True), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train) + train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True) + + # create loss function + criterion = CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS) + + # intiailize + engine, train_dataloader, *_ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + + engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES) + + logger = get_dist_logger() + + trainer = Trainer(engine=engine, logger=logger) + + hook_list = [ + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), + ] + + trainer.fit(train_dataloader=train_dataloader, + max_steps=2, + epochs=NUM_EPOCHS, + hooks=hook_list, + display_progress=True) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_hybrid_parallel(): + spawn(run_trainer, 2) + disable_existing_loggers() + + +if __name__ == '__main__': + test_hybrid_parallel() diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py new file mode 100644 index 0000000000000000000000000000000000000000..39efcd41a1d46e13238fad7e27a5743665bd10f9 --- /dev/null +++ b/tests/test_ddp/test_ddp_ignore_params.py @@ -0,0 +1,92 @@ +import os +import random +from typing import Callable, Type + +import numpy as np +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.nn.parallel import ColoDDP +from colossalai.tensor import ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext, ZeroDDP +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager + + +def set_seed(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +def init_ddp(module: torch.nn.Module) -> ColoDDP: + pg = ProcessGroup() + return ColoDDP(module, process_group=pg) + + +def init_ddpv2(module: torch.nn.Module) -> ZeroDDP: + chunk_config, *_ = search_chunk_configuration(module, 4, 1024) + chunk_manager = ChunkManager(chunk_config) + gemini_manager = GeminiManager('cuda', chunk_manager) + return ZeroDDP(module, gemini_manager) + + +class Net(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(3, 3, bias=False) + self.fc2 = torch.nn.Linear(3, 1, bias=False) + + def forward(self, x): + return self.fc2(self.fc1(x)) + + +def run_fwd_bwd(ddp_cls: Type[ColoDDP], init_ddp_func: Callable[[torch.nn.Module], ColoDDP]): + with ColoInitContext(device=get_current_device()): + model = Net().cuda() + w1 = model.fc1.weight + w2 = model.fc2.weight + ddp_cls.set_params_to_ignore([w2]) + model = init_ddp_func(model) + x = torch.rand(2, 3, device=get_current_device()) + logits = model(x) + loss = torch.sum(logits) + model.backward(loss) + + if ddp_cls is ZeroDDP: + w1s_grad = w1 + else: + w1s_grad = w1.grad + + w1_grads = [torch.empty_like(w1) for _ in range(dist.get_world_size())] + dist.all_gather(w1_grads, w1s_grad) + assert torch.equal(w1_grads[0], w1_grads[1]) + w2_grads = [torch.empty_like(w2) for _ in range(dist.get_world_size())] + dist.all_gather(w2_grads, w2.grad) + assert not torch.equal(w2_grads[0], w2_grads[1]) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + set_seed(dist.get_rank()) + run_fwd_bwd(ColoDDP, init_ddp) + run_fwd_bwd(ZeroDDP, init_ddpv2) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_ddp_ignore_params(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_ddp_ignore_params(2) diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..54f89f972765ecab174a8370ef7024b714fae8b5 --- /dev/null +++ b/tests/test_ddp/test_ddp_state_dict.py @@ -0,0 +1,67 @@ +from collections import OrderedDict + +import pytest +import torch + +import colossalai +from colossalai.nn.parallel import ColoDDP +from colossalai.tensor import ColoParameter, ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs + + +def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict): + for (k1, t1), (k2, t2) in zip(state_dict.items(), other_state_dict.items()): + assert k1 == k2 + + if t1.device != t2.device: + temp_t2 = t2.to(t1.device) + else: + temp_t2 = t2 + + assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2) + + +def init_ddp(module: torch.nn.Module) -> ColoDDP: + pg = ProcessGroup() + return ColoDDP(module, process_group=pg) + + +def run_ddp_state_dict(): + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + torch_model = model_builder().cuda() + with ColoInitContext(device=get_current_device()): + model = model_builder() + model = init_ddp(model) + torch_state_dict = torch_model.state_dict() + + for param in model.parameters(): + if isinstance(param, ColoParameter): + assert param.get_process_group() is not None + model.load_state_dict(torch_state_dict) + + for param in model.parameters(): + if isinstance(param, ColoParameter): + assert param.get_process_group() is not None + + state_dict = model.state_dict() + check_state_dict_equal(torch_state_dict, state_dict) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_ddp_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_state_dict(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_state_dict(2) diff --git a/tests/test_ddp/test_reducer.py b/tests/test_ddp/test_reducer.py new file mode 100644 index 0000000000000000000000000000000000000000..e8d3a112c938d7c883439314cf3e735dec2e8af3 --- /dev/null +++ b/tests/test_ddp/test_reducer.py @@ -0,0 +1,47 @@ +from functools import partial + +import pytest +import torch +import torch.distributed as dist +from torch.distributed.distributed_c10d import _get_default_group + +import colossalai +from colossalai.nn.parallel.reducer import Reducer +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device + +REDUCE_CNT = 0 + + +def check_eq(grad, grad_clone): + global REDUCE_CNT + print(f'Rank{dist.get_rank()} check {REDUCE_CNT}') + REDUCE_CNT += 1 + assert torch.allclose(grad, grad_clone) + + +def run_reducer(): + grads = [torch.rand(64, i + 1, device=get_current_device()) for i in range(10)] + grads_clone = [g.clone().detach() for g in grads] + for g in grads: + dist.all_reduce(g) + reducer = Reducer(bucket_size_mb=1) + for g, g_clone in zip(grads, grads_clone): + reducer.all_reduce_async(g_clone, _get_default_group(), partial(check_eq, g)) + reducer.flush() + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_reducer() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_reducer(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_reducer(2) diff --git a/tests/test_device/test_alpha_beta.py b/tests/test_device/test_alpha_beta.py new file mode 100644 index 0000000000000000000000000000000000000000..ab933ed57d0d2630f5ebd7dad6b86afc1677f6a5 --- /dev/null +++ b/tests/test_device/test_alpha_beta.py @@ -0,0 +1,27 @@ +import pytest + +from colossalai.device import AlphaBetaProfiler +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +def check_alpha_beta(rank, world_size, port, physical_devices): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + profiler = AlphaBetaProfiler(physical_devices) + ab_dict = profiler.profile_ab() + for _, (alpha, beta) in ab_dict.items(): + assert alpha > 0 and alpha < 1e-4 and beta > 0 and beta < 1e-10 + + +@pytest.mark.skip(reason="Skip because assertion fails for CI devices") +@pytest.mark.dist +@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@rerun_if_address_is_in_use() +def test_profile_alpha_beta(physical_devices): + spawn(check_alpha_beta, 4, physical_devices=physical_devices) + + +if __name__ == '__main__': + test_profile_alpha_beta() diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..3be057b3a98bba70c3e162d7148fb9c66ee84792 --- /dev/null +++ b/tests/test_device/test_device_mesh.py @@ -0,0 +1,21 @@ +from colossalai.device.device_mesh import DeviceMesh +import torch + + +def test_device_mesh(): + physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + mesh_shape = (4, 4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7], + # [8, 9, 10,11], + # [12,13,14,15]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + assert device_mesh.convert_map[5] == [1, 1] + assert device_mesh.convert_map[11] == [2, 3] + assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]] + assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]] + assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3] + + +if __name__ == '__main__': + test_device_mesh() diff --git a/tests/test_device/test_extract_alpha_beta.py b/tests/test_device/test_extract_alpha_beta.py new file mode 100644 index 0000000000000000000000000000000000000000..52604b9c6a4909d437ef55f446619d49681ccebe --- /dev/null +++ b/tests/test_device/test_extract_alpha_beta.py @@ -0,0 +1,30 @@ +import pytest + +from colossalai.device import AlphaBetaProfiler +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +def check_extract_alpha_beta(rank, world_size, port, physical_devices): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + profiler = AlphaBetaProfiler(physical_devices) + + mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh() + for alpha in mesh_alpha: + assert alpha > 0 and alpha < 1e-3 + for beta in mesh_beta: + assert beta > 0 and beta < 1e-10 + + +@pytest.mark.skip(reason="Skip because assertion may fail for CI devices") +@pytest.mark.dist +@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@rerun_if_address_is_in_use() +def test_profile_alpha_beta(physical_devices): + spawn(check_extract_alpha_beta, 4, physical_devices=physical_devices) + + +if __name__ == '__main__': + test_profile_alpha_beta() diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py new file mode 100644 index 0000000000000000000000000000000000000000..2b7060c4846aef61d7680e4884fe0f87526f8795 --- /dev/null +++ b/tests/test_device/test_init_logical_pg.py @@ -0,0 +1,44 @@ +import pytest +import torch +import torch.distributed as dist +from torch.distributed import ReduceOp + +from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_layer(rank, world_size, port): + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + physical_mesh_id = torch.arange(0, 4) + assert rank == gpc.get_global_rank() + + tensor_to_check = torch.tensor([2, 2, 2, 2]).cuda() + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]} + logical_process_groups = device_mesh.process_groups_dict + + for mesh_dim, pgs in logical_pg_dict.items(): + for index, pg in enumerate(pgs): + if rank in pg: + tensor = torch.ones(4).cuda() + group = logical_process_groups[mesh_dim][index][1] + dist.all_reduce(tensor, op=ReduceOp.SUM, group=group) + assert tensor.equal(tensor_to_check) + + gpc.destroy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_logical_pg(): + spawn(check_layer, 4) + + +if __name__ == '__main__': + test_logical_pg() diff --git a/tests/test_device/test_search_logical_device_mesh.py b/tests/test_device/test_search_logical_device_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..b22a76eabc2fec34ec7d9922a9cd988c58c576aa --- /dev/null +++ b/tests/test_device/test_search_logical_device_mesh.py @@ -0,0 +1,30 @@ +import pytest + +from colossalai.device import AlphaBetaProfiler +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +def check_alpha_beta(rank, world_size, port, physical_devices): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + profiler = AlphaBetaProfiler(physical_devices) + best_logical_mesh = profiler.search_best_logical_mesh() + + if physical_devices == [0, 1, 2, 3]: + assert best_logical_mesh == [[0, 1], [2, 3]] + elif physical_devices == [0, 3]: + assert best_logical_mesh == [[0, 3]] + + +@pytest.mark.skip(reason="Skip because assertion may fail for CI devices") +@pytest.mark.dist +@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@rerun_if_address_is_in_use() +def test_profile_alpha_beta(physical_devices): + spawn(check_alpha_beta, 4, physical_devices=physical_devices) + + +if __name__ == '__main__': + test_profile_alpha_beta() diff --git a/tests/test_engine/test_engine.py b/tests/test_engine/test_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..62493cf3712dd46eba2bde0dc3d34531725230e3 --- /dev/null +++ b/tests/test_engine/test_engine.py @@ -0,0 +1,62 @@ +import pytest + +import colossalai +from colossalai.amp import AMP_TYPE +from colossalai.core import global_context as gpc +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from tests.components_to_test.registry import non_distributed_component_funcs + +CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), + fp16=dict(mode=None), + clip_grad_norm=1.0) + + +@parameterize('model_name', ['repeated_computed_layers', 'resnet18', 'repeated_computed_layers']) +@parameterize('amp_mode', [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None]) +def run_train(model_name, amp_mode): + # FIXME: test bert + get_components_func = non_distributed_component_funcs.get_callable(model_name) + gpc.config.fp16['mode'] = amp_mode + model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() + + model = model_builder(checkpoint=False) + engine, train_dataloader, *args = colossalai.initialize(model=model, + optimizer=optimizer_class(model.parameters(), lr=1e-3), + criterion=criterion, + train_dataloader=train_dataloader) + + try: + engine.train() + for data, label in train_dataloader: + engine.zero_grad() + data = data.cuda() + label = label.cuda() + if criterion: + output = engine(data) + loss = engine.criterion(output, label) + else: + loss = engine(data, label) + engine.backward(loss) + engine.step() + break + except IndexError: + # if using apex amp, NetWithRepeatedlyComputedLayers will raise an index out of range issue + # the following check fails in apex + # if cached_x.grad_fn.next_functions[1][0].variable is not x: + pass + + +def run_engine(rank, world_size, port): + # init dist env + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_train() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_engine(): + spawn(run_engine, 2) + + +if __name__ == '__main__': + test_engine() diff --git a/tests/test_engine/test_gradient_accumluation.py b/tests/test_engine/test_gradient_accumluation.py new file mode 100644 index 0000000000000000000000000000000000000000..7783827c7c44c929a4a66667ffdf24ffc2d1aa8f --- /dev/null +++ b/tests/test_engine/test_gradient_accumluation.py @@ -0,0 +1,95 @@ +import os +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +from torch.optim import Adam +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet18 + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_dataloader + +# Config +BATCH_SIZE = 2 +NUM_CLASSES = 10 + +CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), + clip_grad_norm=1.0, + gradient_accumulation=4) + + +def run_no_pipeline(rank, world_size, port): + + # init dist env + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # build model + model = resnet18(num_classes=10) + + # build dataloaders + train_dataset = CIFAR10(root=Path(os.environ['DATA']), + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ])) + train_dataloader = get_dataloader(dataset=train_dataset, + shuffle=True, + batch_size=BATCH_SIZE, + pin_memory=True, + drop_last=True) + + # build optimizer + optimizer = Adam(model.parameters(), lr=0.001) + criterion = nn.CrossEntropyLoss() + + engine, train_dataloader, *args = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + logger = get_dist_logger() + rank = torch.distributed.get_rank() + param_track = [] + grad_track = [] + next(model.parameters()).retain_grad() + + engine.train() + step = 0 + for img, label in train_dataloader: + engine.zero_grad() + img = img.cuda() + label = label.cuda() + output = engine(img) + loss = engine.criterion(output, label) + engine.backward(loss) + engine.step() + + # check + param_track.append(next(model.parameters())[0].clone()) + grad_track.append(next(model.parameters()).grad[0].clone()) + step += 1 + if step == CONFIG['gradient_accumulation']: + break + + assert not torch.all(grad_track[0] == grad_track[-1]), 'grad should be different in different iterations' + assert torch.all(param_track[0] == param_track[1]) and not torch.all(param_track[0] == param_track[-1]), \ + 'param should be the same in the first few iterations and only changed in the last iteration' + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_engine(): + spawn(run_no_pipeline, 4) + + +if __name__ == '__main__': + test_engine() diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..ab483f7e47a3e0821b36f06e6ecfad6f56e90a83 --- /dev/null +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -0,0 +1,182 @@ +import pytest +import torch +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.fx import ColoTracer +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.testing import rerun_if_address_is_in_use, spawn + +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True +except: + # fall back to older pytorch version + from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False + + +class MLP(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.linear1(x), self.linear2(x) + + +class relu(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x): + return self.relu(x) + + +class MyModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.mlp1 = MLP() + self.relu = relu() + self.linear2 = torch.nn.Linear(4, 4) + + def ckpt2(self, x): + return F.relu(x, inplace=True) + + def ckpt3(self, x, y): + return self.linear2(x) + self.linear2(y) + + def forward(self, x, y): + y1, y2 = checkpoint(self.mlp1, x) + y3 = checkpoint(self.relu, x) + + y4 = checkpoint(self.ckpt2, y) + y5 = checkpoint(self.ckpt3, y, y4) + y6 = self.linear2(y4) + return y1 + y2 + y3 + y4 + y5 + y6 + + +def _run_act_ckpt_codegen(rank, world_size, port): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # build model and run forward + model = MyModule() + data1 = torch.rand(4, 4) + data2 = torch.rand(4, 4) + + # copy model to cuda + model = model.to(device="cuda") + data1 = data1.to(device="cuda") + data2 = data2.to(device="cuda") + + non_fx_out = model(data1, data2) + + # trace the module and replace codegen + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(model) + codegen = ActivationCheckpointCodeGen() + graph.set_codegen(codegen) + + # check ops are annotated with ckpt + # also annotate the selected node for offloading + ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu'] + offload_starts = ['mlp1_linear1'] + for node in graph.nodes: + if node.name in ckpt_nodes: + assert 'activation_checkpoint' in node.meta + + # annotate the selected node for offload + if node.name in offload_starts: + node.meta['activation_offload'] = True + + gm = ColoGraphModule(model, graph) + gm.recompile() + + # assert checkpoint function will be generated and + # the offload option is correct + code = graph.python_code('self').src + assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code + + # recompile and verify the outputs are consistent + fx_out = gm(data1, data2) + assert torch.equal(non_fx_out, fx_out) + + gpc.destroy() + + +@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@rerun_if_address_is_in_use() +def test_act_ckpt_codegen(): + spawn(_run_act_ckpt_codegen, 1) + + +def _run_act_ckpt_python_code_torch11(rank, world_size, port): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # build model and run forward + model = MyModule() + data1 = torch.rand(4, 4) + data2 = torch.rand(4, 4) + + # copy model to cuda + data1 = data1.to(device="cuda") + data2 = data2.to(device="cuda") + + non_fx_out = model(data1, data2) + + # trace the module and replace codegen + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(model) + + # replace a bound method of an object + graph._python_code = python_code_with_activation_checkpoint.__get__(graph) + + # check ops are annotated with ckpt + ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu'] + offload_starts = ['mlp1_linear1'] + for node in graph.nodes: + if node.name in ckpt_nodes: + assert 'activation_checkpoint' in node.meta + + # annotate the selected node for offload + if node.name in offload_starts: + node.meta['activation_offload'] = True + + gm = ColoGraphModule(model, graph) + gm.recompile() + # assert checkpoint function will be generated and + # the offload option is correct + code = graph.python_code('self').src + assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code + + # recompile and verify the outputs are consistent + fx_out = gm(data1, data2) + assert torch.equal(non_fx_out, fx_out) + + gpc.destroy() + + +@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") +@rerun_if_address_is_in_use() +def test_act_ckpt_python_code_torch11(): + spawn(_run_act_ckpt_python_code_torch11, 1) + + +if __name__ == '__main__': + _run_act_ckpt_codegen(rank=0) diff --git a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..9064023d4f68299e0520ba1f0321a0d91b239527 --- /dev/null +++ b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py @@ -0,0 +1,151 @@ +import pytest +import torch + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.fx import ColoTracer +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.testing import rerun_if_address_is_in_use, spawn + +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True +except: + # fall back to older pytorch version + from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False + + +class MyModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + self.linear3 = torch.nn.Linear(4, 4) + self.linear4 = torch.nn.Linear(4, 4) + self.linear5 = torch.nn.Linear(4, 4) + self.linear6 = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.linear6(self.linear5(self.linear4(self.linear3(self.linear2(self.linear1(x)))))) + + +def _run_act_ckpt_codegen(rank, world_size, port): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # build model and run forward + model = MyModule() + data1 = torch.rand(4, 4) + + # copy model to cuda + model = model.to(device="cuda") + data1 = data1.to(device="cuda") + + non_fx_out = model(data1) + + # trace the module and replace codegen + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(model) + codegen = ActivationCheckpointCodeGen() + graph.set_codegen(codegen) + + # annotate nested checkpoint + for node in graph.nodes: + if node.name == "linear1": + node.meta['activation_checkpoint'] = [0, 0, 0] + continue + if node.name == "linear2": + node.meta['activation_checkpoint'] = [0, 0, None] + if node.name == "linear3": + node.meta['activation_checkpoint'] = [0, 0, 1] + if node.name == "linear4": + node.meta['activation_checkpoint'] = [0, 1, None] + if node.name == "linear5": + node.meta['activation_checkpoint'] = 1 + gm = ColoGraphModule(model, graph) + gm.recompile() + + # assert checkpoint function will be generated and + code = graph.python_code('self').src + assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code + + # recompile and verify the outputs are consistent + fx_out = gm(data1) + assert torch.equal(non_fx_out, fx_out) + + gpc.destroy() + + +@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +def test_act_ckpt_codegen(): + spawn(_run_act_ckpt_codegen, 1) + + +def _run_act_ckpt_python_code_torch11(rank, world_size, port): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # build model and run forward + model = MyModule() + data1 = torch.rand(4, 4) + + # copy model to cuda + model = model.to(device="cuda") + data1 = data1.to(device="cuda") + + non_fx_out = model(data1) + + # trace the module and replace codegen + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(model) + codegen = ActivationCheckpointCodeGen() + graph.set_codegen(codegen) + + # annotate nested checkpoint + for node in graph.nodes: + if node.name == "linear1": + node.meta['activation_checkpoint'] = [0, 0, 0] + continue + if node.name == "linear2": + node.meta['activation_checkpoint'] = [0, 0, None] + if node.name == "linear3": + node.meta['activation_checkpoint'] = [0, 0, 1] + if node.name == "linear4": + node.meta['activation_checkpoint'] = [0, 1, None] + if node.name == "linear5": + node.meta['activation_checkpoint'] = 1 + gm = ColoGraphModule(model, graph) + gm.recompile() + + # assert checkpoint function will be generated and + code = graph.python_code('self').src + assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code + + # recompile and verify the outputs are consistent + fx_out = gm(data1) + assert torch.equal(non_fx_out, fx_out) + + gpc.destroy() + + +@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") +@rerun_if_address_is_in_use() +def test_act_ckpt_python_code_torch11(): + spawn(_run_act_ckpt_python_code_torch11, 1) + + +if __name__ == '__main__': + _run_act_ckpt_codegen(rank=0) diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..96e88eb92b3319bf8cd54c9a0175d6e9bf5f5617 --- /dev/null +++ b/tests/test_fx/test_codegen/test_offload_codegen.py @@ -0,0 +1,179 @@ +import copy + +import pytest +import torch +from torch.fx import GraphModule + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.fx import ColoTracer +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.testing import rerun_if_address_is_in_use, spawn + +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True +except: + # fall back to older pytorch version + from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False + + +class MyNet(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear0 = torch.nn.Linear(4, 4) + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + self.linear3 = torch.nn.Linear(4, 4) + self.linear4 = torch.nn.Linear(4, 4) + self.linear5 = torch.nn.Linear(4, 4) + self.linear6 = torch.nn.Linear(4, 4) + + def forward(self, x): + x = self.linear0(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + x = self.linear5(x) + x = self.linear6(x) + return x + + +def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool: + for m_p, gm_p in zip(m.parameters(), gm.parameters()): + if not torch.allclose(m_p.grad, gm_p.grad): + return False + return True + + +def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor): + + # test forward + non_fx_out = model(data) + fx_out = gm(data) + assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output" + + # test barckward + loss0 = non_fx_out.sum() + loss0.backward() + loss1 = fx_out.sum() + loss1.backward() + assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one" + + +def _run_offload_codegen(rank, world_size, port): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # build model and input + model = MyNet().cuda() + data = torch.rand(4, 4).cuda() + + # trace the module and replace codegen + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(model) + codegen = ActivationCheckpointCodeGen() + graph.set_codegen(codegen) + + # annotate the activation offload part + # also annotate the activation_checkpoint so we could test both types + # of input offload + for node in graph.nodes: + if node.name == "linear0": + node.meta['activation_offload'] = [0, True, False] + if node.name == "linear1": + node.meta['activation_offload'] = [0, True, False] + if node.name == "linear2": + node.meta['activation_offload'] = [1, True, True] + if node.name == "linear4": + node.meta['activation_offload'] = [2, False, True] + if node.name == "linear5": + node.meta['activation_checkpoint'] = [0] + node.meta['activation_offload'] = True + + gm = ColoGraphModule(copy.deepcopy(model), graph) + gm.recompile() + + # assert we have all the components + code = graph.python_code("self").src + assert "def pack_hook_input(self, x):" in code and \ + "def unpack_hook(self, packed):" in code and \ + "def pack_hook_no_input(self, x):" in code and \ + "setattr(x, 'offload', True)" in code and \ + "setattr(linear3, 'offload', False)" in code and \ + "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ + "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ + "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code + + _test_fwd_and_bwd(model, gm, data) + gpc.destroy() + + +@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@rerun_if_address_is_in_use() +def test_act_ckpt_codegen(): + spawn(_run_offload_codegen, 1) + + +def _run_offload_codegen_torch11(rank, world_size, port): + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # build model and input + model = MyNet().cuda() + data = torch.rand(4, 4).cuda() + + # trace the module and replace codegen + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(model) + + # replace a bound method of an object + graph._python_code = python_code_with_activation_checkpoint.__get__(graph) + + # annotate the activation offload part + # also annotate the activation_checkpoint so we could test both types + # of input offload + for node in graph.nodes: + if node.name == "linear0": + node.meta['activation_offload'] = [0, True, False] + if node.name == "linear1": + node.meta['activation_offload'] = [0, True, False] + if node.name == "linear2": + node.meta['activation_offload'] = [1, True, True] + if node.name == "linear4": + node.meta['activation_offload'] = [2, False, True] + if node.name == "linear5": + node.meta['activation_checkpoint'] = [0] + node.meta['activation_offload'] = True + + gm = ColoGraphModule(copy.deepcopy(model), graph) + gm.recompile() + + # assert we have all the components + code = graph.python_code("self").src + assert "def pack_hook_input(self, x):" in code and \ + "def unpack_hook(self, packed):" in code and \ + "def pack_hook_no_input(self, x):" in code and \ + "setattr(x, 'offload', True)" in code and \ + "setattr(linear3, 'offload', False)" in code and \ + "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ + "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ + "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code + + _test_fwd_and_bwd(model, gm, data) + gpc.destroy() + + +@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented") +@rerun_if_address_is_in_use() +def test_act_ckpt_python_code_torch11(): + spawn(_run_offload_codegen_torch11, 1) + + +if __name__ == "__main__": + _run_offload_codegen(0) diff --git a/tests/test_fx/test_coloproxy.py b/tests/test_fx/test_coloproxy.py new file mode 100644 index 0000000000000000000000000000000000000000..96cf5198da10a5eb3d0d595e73d9b796547afe86 --- /dev/null +++ b/tests/test_fx/test_coloproxy.py @@ -0,0 +1,51 @@ +import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule + +from colossalai.fx.proxy import ColoProxy +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.testing import clear_cache_before_run + + +class Conv1D(nn.Module): + + def __init__(self, nf, nx): + super().__init__() + self.nf = nf + w = torch.empty(nx, nf) + nn.init.normal_(w, std=0.02) + self.weight = nn.Parameter(w) + self.bias = nn.Parameter(torch.zeros(nf)) + + def forward(self, x): + size_out = x.shape[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) + return x + + +@clear_cache_before_run() +def test_coloproxy(): + + tracer = ColoTracer() + model = Conv1D(3, 3) + input_sample = {'x': torch.rand(3, 3).to('meta')} + + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + node = list(gm.graph.nodes)[0] + + proxy = ColoProxy(node=node, tracer=tracer) + proxy.meta_data = torch.empty(4, 2, device='meta') + + assert len(proxy) == 4 + assert proxy.shape[0] == 4 and proxy.shape[1] == 2 + assert proxy.dim() == 2 + assert proxy.dtype == torch.float32 + assert proxy.size(0) == 4 + + +if __name__ == '__main__': + test_coloproxy() diff --git a/tests/test_fx/test_comm_size_compute.py b/tests/test_fx/test_comm_size_compute.py new file mode 100644 index 0000000000000000000000000000000000000000..d3daadd714064b38112d5af976e9a32725faddf7 --- /dev/null +++ b/tests/test_fx/test_comm_size_compute.py @@ -0,0 +1,53 @@ +import torch +from torch.fx import symbolic_trace + +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.passes.utils import get_comm_size +from colossalai.testing import clear_cache_before_run + +is_compatible = is_compatible_with_meta() +if is_compatible: + from colossalai.fx.profiler import MetaTensor + +MODEL_DIM = 16 +BATCH_SIZE = 8 +PIPELINE_SIZE = 2 + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.linear3 = torch.nn.Linear(dim, dim) + self.linear4 = torch.nn.Linear(dim, dim) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + return x + + +@clear_cache_before_run() +def test_comm_size_compute(): + model = MLP(MODEL_DIM) + input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta') + gm = symbolic_trace(model) + if is_compatible: + input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device) + MetaInfoProp(gm).run(input_sample) + annotated_model = uniform_split_pass(gm, PIPELINE_SIZE) + split_model, split_submodules = split_with_split_nodes_pass(annotated_model) + submodule_list = list(split_model.children()) + comm_size = get_comm_size(submodule_list[0], submodule_list[1]) + # the shape of tensor send from partition 0 to partition 1 is (8, 16) + assert comm_size == 128 + + +if __name__ == '__main__': + test_comm_size_compute() diff --git a/tests/test_fx/test_graph_manipulation.py b/tests/test_fx/test_graph_manipulation.py new file mode 100644 index 0000000000000000000000000000000000000000..175b69dd96fe53204ed1263daab40b923ed54206 --- /dev/null +++ b/tests/test_fx/test_graph_manipulation.py @@ -0,0 +1,53 @@ +import torch +from torch.fx import GraphModule + +import colossalai +from colossalai.fx import ColoTracer +from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata +from colossalai.fx.passes.utils import assign_bfs_level_to_nodes, get_leaf, get_top +from colossalai.testing import clear_cache_before_run + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.linear3 = torch.nn.Linear(dim, dim) + self.linear4 = torch.nn.Linear(dim, dim) + self.linear5 = torch.nn.Linear(dim, dim) + + def forward(self, x): + l1 = self.linear1(x) + l2 = self.linear2(x) + l3 = self.linear3(l1) + l4 = self.linear4(l2) + l5 = self.linear5(l3) + return l4, l5 + + +@clear_cache_before_run() +def test_graph_manipulation(): + model = MLP(4) + tracer = ColoTracer() + graph = tracer.trace(model) + nodes = list(graph.nodes) + x, l1, l2, l3, l4, l5, output = nodes + + leaf_nodes = set(get_leaf(graph)) + top_nodes = set(get_top(graph)) + compare_dict = {x: None, l1: 0, l2: 0, l3: 1, l4: 1, l5: 2, output: None} + assign_bfs_level_to_nodes(graph) + + assert leaf_nodes == set([l4, l5]) + assert top_nodes == set([l1, l2]) + for node in graph.nodes: + if node.op in ('placeholder', 'output'): + assert not hasattr(node, 'bfs_level') + else: + assert node.bfs_level == compare_dict[node] + + +if __name__ == '__main__': + test_graph_manipulation() diff --git a/tests/test_fx/test_meta/test_aten.py b/tests/test_fx/test_meta/test_aten.py new file mode 100644 index 0000000000000000000000000000000000000000..e490522dbf15677f5af4fa5b8c04520338101585 --- /dev/null +++ b/tests/test_fx/test_meta/test_aten.py @@ -0,0 +1,84 @@ +from typing import Any, Callable, Union + +import pytest +import torch +import torch.nn as nn + +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.testing import clear_cache_before_run + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +aten = torch.ops.aten + +registered_meta = { + ('aten.convolution.default', True): [ # (aten ops, requires_backward) + (nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), + (nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)), + (nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)), + (nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), + (nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, + dilation=2), torch.rand(2, 3, 4, 4)), + (nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, + dilation=2), torch.rand(2, 3, 4, 4, 4)), + ], + ('aten.native_batch_norm.default', True): [ + (nn.BatchNorm1d(4), torch.rand(2, 4)), + (nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)), + (nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)), + ], + ('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),], + ('aten.avg_pool1d.default', True): [ + (nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)), + (nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)), + (nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)), + (nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)), + ], + ('aten.avg_pool2d.default', True): [ + (nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), + (nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), + (nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)), + (nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)), + ], + ('aten.relu.default', True): [ + (nn.ReLU(), torch.rand(4, 3, 1, 2)), + (nn.LeakyReLU(), torch.rand(4, 3, 1, 2)), + (nn.SiLU(), torch.rand(4, 3, 1, 2)), + (nn.GELU(), torch.rand(4, 3, 1, 2)), + (nn.ELU(), torch.rand(4, 3, 1, 2)), + (nn.Sigmoid(), torch.rand(4, 3, 1, 2)), + (nn.Tanh(), torch.rand(4, 3, 1, 2)), + (nn.Hardswish(), torch.rand(4, 3, 1, 2)), + ] +} + + +def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any: + assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.' + assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.' + assert tensor.stride() == meta_tensor.stride( + ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.' + + +def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any: + x.requires_grad = requires_backward + meta_x = MetaTensor(x) + x_out, meta_out = f(x), f(meta_x) + compare_all(x_out, meta_out) + if requires_backward: + x_out.sum().backward() + meta_out.sum().backward() + compare_all(x.grad, meta_x.grad) + + +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() +def test_meta_aten(): + for (aten_op, requires_backward), v in registered_meta.items(): + for f, x in v: + run_and_compare(f, x, requires_backward) + + +if __name__ == '__main__': + test_meta_aten() diff --git a/tests/test_fx/test_meta/test_backward.py b/tests/test_fx/test_meta/test_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..7aed6fd4597b2286bbcda9627c7dd1ba0b482593 --- /dev/null +++ b/tests/test_fx/test_meta/test_backward.py @@ -0,0 +1,53 @@ +import pytest +import timm.models as tmm +import torch +import torchvision.models as tm + +from colossalai.fx._compatibility import is_compatible_with_meta + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +from colossalai.testing import clear_cache_before_run + +tm_models = [ + tm.vgg11, + tm.resnet18, + tm.densenet121, + tm.mobilenet_v3_small, + tm.resnext50_32x4d, + tm.wide_resnet50_2, + tm.regnet_x_16gf, + tm.mnasnet0_5, + tm.efficientnet_b0, +] + +tmm_models = [ + tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m, + tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224, + tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100, + tmm.swin_transformer.swin_base_patch4_window7_224 +] + + +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() +def test_torchvision_models(): + for m in tm_models: + model = m() + data = torch.rand(100000, 3, 224, 224, device='meta') + model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward() + + +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() +def test_timm_models(): + for m in tmm_models: + model = m() + data = torch.rand(100000, 3, 224, 224, device='meta') + model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward() + + +if __name__ == '__main__': + test_torchvision_models() + test_timm_models() diff --git a/tests/test_fx/test_meta/test_meta_trace.py b/tests/test_fx/test_meta/test_meta_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..61614f8a66234260f54f8286bcfe628d55452eab --- /dev/null +++ b/tests/test_fx/test_meta/test_meta_trace.py @@ -0,0 +1,53 @@ +import pytest +import timm.models as tmm +import torch +import torchvision.models as tm + +from colossalai.fx._compatibility import is_compatible_with_meta + +if is_compatible_with_meta(): + from colossalai.fx import meta_trace + +from colossalai.testing import clear_cache_before_run + +tm_models = [ + tm.vgg11, + tm.resnet18, + tm.densenet121, + tm.mobilenet_v3_small, + tm.resnext50_32x4d, + tm.wide_resnet50_2, + tm.regnet_x_16gf, + tm.mnasnet0_5, + tm.efficientnet_b0, +] + +tmm_models = [ + tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m, + tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224, + tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100, + tmm.swin_transformer.swin_base_patch4_window7_224 +] + + +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() +def test_torchvision_models_trace(): + for m in tm_models: + model = m() + data = torch.rand(1000, 3, 224, 224, device='meta') + graph = meta_trace(model, torch.device('cpu'), data) + + +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() +def test_timm_models_trace(): + for m in tmm_models: + model = m() + data = torch.rand(1000, 3, 224, 224, device='meta') + graph = meta_trace(model, torch.device('cpu'), data) + + +if __name__ == '__main__': + test_torchvision_models_trace() + test_timm_models_trace() diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..a12512696a730459b80a9debc4cfa43695be8515 --- /dev/null +++ b/tests/test_fx/test_meta_info_prop.py @@ -0,0 +1,40 @@ +import torch +from torch.fx import symbolic_trace + +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata +from colossalai.testing import clear_cache_before_run + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +BATCH_SIZE = 2 +DIM_IN = 4 +DIM_OUT = 16 + + +def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): + assert meta_info_spec.shape == orig_tensor.shape + assert meta_info_spec.dtype == orig_tensor.dtype + assert meta_info_spec.stride == orig_tensor.stride() + assert meta_info_spec.numel == orig_tensor.numel() + + +@clear_cache_before_run() +def test_meta_info_prop(): + model = torch.nn.Linear(DIM_IN, DIM_OUT) + input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') + if is_compatible_with_meta(): + input_sample = MetaTensor(input_sample, fake_device='cpu') + orig_output = model(input_sample) + gm = symbolic_trace(model) + MetaInfoProp(gm).run(input_sample) + for node in gm.graph.nodes: + if node.op == 'placeholder': + meta_check(node.meta['tensor_meta'], input_sample) + if node.op == 'output': + meta_check(node.meta['tensor_meta'], orig_output) + + +if __name__ == '__main__': + test_meta_info_prop() diff --git a/tests/test_fx/test_parallel_1d.py b/tests/test_fx/test_parallel_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..1044be7db1f4cbab519d8a90a70f8f3deb3016a2 --- /dev/null +++ b/tests/test_fx/test_parallel_1d.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pytest +import torch +from torch.fx import symbolic_trace + +from colossalai.core import global_context as gpc +from colossalai.fx.passes import column_shard_linear_pass +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.linear3 = torch.nn.Linear(dim, dim) + self.linear4 = torch.nn.Linear(dim, dim) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + return x + + +CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=2))) + + +def check_layer(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + input_tensor = torch.rand(2, 16).cuda() + model = MLP(16).cuda() + symbolic_traced = symbolic_trace(model) + output = model(input_tensor) + splitted_gm = column_shard_linear_pass(symbolic_traced) + new_output = splitted_gm(input_tensor) + + assert output.equal(new_output) + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@clear_cache_before_run() +@rerun_if_address_is_in_use() +def test_1d(): + spawn(check_layer, 2) + + +if __name__ == '__main__': + test_1d() diff --git a/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py b/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3afc6c97e2bb69778b5fce5668861cba3d3582eb --- /dev/null +++ b/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py @@ -0,0 +1,69 @@ +import torch +from torch.fx import symbolic_trace +from torch.fx import GraphModule +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass +from colossalai.fx import ColoTracer +import inspect +import random +import numpy as np + +MANUAL_SEED = 0 +random.seed(MANUAL_SEED) +np.random.seed(MANUAL_SEED) +torch.manual_seed(MANUAL_SEED) + + +def split_model_and_compare_output(model, data_gen): + model.eval() + + # generate input sample + kwargs = data_gen() + + # get origin output and rng state + cpu_rng_state = torch.get_rng_state() + output = model(**kwargs) + + # tracing model + tracer = ColoTracer() + try: + meta_args = {k: v.to('meta') for k, v in kwargs.items()} + graph = tracer.trace(root=model, meta_args=meta_args) + except Exception as e: + raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + # apply transform passes + annotated_model = balanced_split_pass(gm, 2) + split_model, split_submodules = split_with_split_nodes_pass(annotated_model) + + # get split model + model_part0 = list(split_model.children())[0] + model_part1 = list(split_model.children())[1] + + # set rng state and compute output of split model + torch.set_rng_state(cpu_rng_state) + output_part0 = model_part0(**kwargs) + sig = inspect.signature(model_part1.forward) + if isinstance(output_part0, torch.Tensor): + output_part1 = model_part1(output_part0) + else: + if len(output_part0) > len(sig.parameters): + output_part0 = output_part0[:len(sig.parameters)] + output_part1 = model_part1(*output_part0) + + # get output tensor from HFOutput datastructure + if 'logits' in output: + output_to_compare = output['logits'] + elif 'prediction_logits' in output: + output_to_compare = output['prediction_logits'] + else: + output_to_compare = output['last_hidden_state'] + + # compare output + if isinstance(output_part1, torch.Tensor): + assert output_to_compare.equal(output_part1) + elif isinstance(output_part1, (tuple, list)): + assert output_to_compare.equal(output_part1[0]) + else: + assert False diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py new file mode 100644 index 0000000000000000000000000000000000000000..6ef861bdefbe105299e31e8f06f8d1eb464fb3c2 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py @@ -0,0 +1,40 @@ +import pytest +import torch +import transformers +from hf_utils import split_model_and_compare_output + +BATCH_SIZE = 2 +SEQ_LENGHT = 16 + + +@pytest.mark.skip('balance split v2 is not ready') +def test_single_sentence_albert(): + MODEL_LIST = [ + transformers.AlbertModel, + transformers.AlbertForPreTraining, + transformers.AlbertForMaskedLM, + transformers.AlbertForSequenceClassification, + transformers.AlbertForTokenClassification, + ] + + config = transformers.AlbertConfig(vocab_size=100, + embedding_size=128, + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=256) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + return meta_args + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + split_model_and_compare_output(model, data_gen) + + +if __name__ == '__main__': + test_single_sentence_albert() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..a7550413fac8f6514e1428b8d7b3999608d84bbc --- /dev/null +++ b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py @@ -0,0 +1,40 @@ +import pytest +import torch +import transformers +from hf_utils import split_model_and_compare_output + +BATCH_SIZE = 2 +SEQ_LENGHT = 16 + + +@pytest.mark.skip('balance split v2 is not ready') +def test_single_sentence_bert(): + MODEL_LIST = [ + transformers.BertModel, + transformers.BertForPreTraining, + transformers.BertLMHeadModel, + transformers.BertForMaskedLM, + transformers.BertForSequenceClassification, + transformers.BertForTokenClassification, + ] + + config = transformers.BertConfig(vocab_size=100, + hidden_size=128, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=256) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + return meta_args + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + split_model_and_compare_output(model, data_gen) + + +if __name__ == '__main__': + test_single_sentence_bert() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..6181c5c0706a97d895fb4d51d012809a2a5336a4 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py @@ -0,0 +1,36 @@ +import pytest +import torch +import transformers +from hf_utils import split_model_and_compare_output + +BATCH_SIZE = 64 +SEQ_LENGHT = 16 +NUM_EPOCHS = 2 +NUM_CHUNKS = 1 + + +@pytest.mark.skip('balance split v2 is not ready') +def test_gpt(): + MODEL_LIST = [ + transformers.GPT2Model, + transformers.GPT2LMHeadModel, + transformers.GPT2DoubleHeadsModel, + transformers.GPT2ForTokenClassification, + # transformers.GPT2ForSequenceClassification, # not supported yet + ] + config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=8) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + return kwargs + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + split_model_and_compare_output(model, data_gen) + + +if __name__ == '__main__': + test_gpt() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..1a9b36be82bd9e9fb29580b6fe6807dbb1eb032a --- /dev/null +++ b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py @@ -0,0 +1,31 @@ +import pytest +import torch +import transformers +from hf_utils import split_model_and_compare_output + +BATCH_SIZE = 1 +SEQ_LENGHT = 16 + + +@pytest.mark.skip('balance split v2 is not ready') +def test_opt(): + MODEL_LIST = [ + transformers.OPTModel, + transformers.OPTForCausalLM, + ] + + config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + return kwargs + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + split_model_and_compare_output(model, data_gen) + + +if __name__ == '__main__': + test_opt() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..16d0163746b318dec6be47529b9f9c55d23d8a7e --- /dev/null +++ b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py @@ -0,0 +1,43 @@ +import pytest +import torch +import transformers +from hf_utils import split_model_and_compare_output + +BATCH_SIZE = 1 +SEQ_LENGHT = 16 + + +@pytest.mark.skip('balance split v2 is not ready') +def test_t5(): + MODEL_LIST = [ + transformers.T5Model, + transformers.T5ForConditionalGeneration, + transformers.T5EncoderModel, + ] + + config = transformers.T5Config(vocab_size=100, d_model=128, num_layers=2) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + return kwargs + + def data_gen_for_encoder_only(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + kwargs = dict(input_ids=input_ids) + return kwargs + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + + if isinstance(model, transformers.T5EncoderModel): + data_gen_func = data_gen_for_encoder_only + else: + data_gen_func = data_gen + + split_model_and_compare_output(model, data_gen_func) + + +if __name__ == '__main__': + test_t5() diff --git a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb1f6f4bb237db355e54d7b57730b3e9a6990a5 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py @@ -0,0 +1,48 @@ +import pytest +import timm.models as tm +import torch +from timm_utils import split_model_and_compare_output + + +@pytest.mark.skip('balance split v2 is not ready') +def test_timm_models_without_control_flow(): + + MODEL_LIST = [ + tm.resnest.resnest50d, + tm.beit.beit_base_patch16_224, + tm.cait.cait_s24_224, + tm.convmixer.convmixer_768_32, + tm.efficientnet.efficientnetv2_m, + tm.resmlp_12_224, + tm.vision_transformer.vit_base_patch16_224, + tm.deit_base_distilled_patch16_224, + ] + + data = torch.rand(2, 3, 224, 224) + + for model_cls in MODEL_LIST: + model = model_cls() + split_model_and_compare_output(model, data) + + +@pytest.mark.skip('balance split v2 is not ready') +def test_timm_models_with_control_flow(): + torch.backends.cudnn.deterministic = True + + MODEL_LIST_WITH_CONTROL_FLOW = [ + tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100, + tm.swin_transformer.swin_base_patch4_window7_224 + ] + + data = torch.rand(2, 3, 224, 224) + + meta_args = {'x': data.to('meta')} + + for model_cls in MODEL_LIST_WITH_CONTROL_FLOW: + model = model_cls() + split_model_and_compare_output(model, data, meta_args) + + +if __name__ == '__main__': + test_timm_models_without_control_flow() + test_timm_models_with_control_flow() diff --git a/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py b/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aa870e5c7a659c5a9c09eae5a9a5d5ec3b77220f --- /dev/null +++ b/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py @@ -0,0 +1,51 @@ +import torch +from torch.fx import symbolic_trace +from torch.fx import GraphModule +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass +from colossalai.fx import ColoTracer +import inspect +import random +import numpy as np + +MANUAL_SEED = 0 +random.seed(MANUAL_SEED) +np.random.seed(MANUAL_SEED) +torch.manual_seed(MANUAL_SEED) +torch.backends.cudnn.deterministic = True + + +def split_model_and_compare_output(model, data, meta_args=None): + model.eval() + + # get origin output and rng state + cpu_rng_state = torch.get_rng_state() + output = model(data) + + # tracing model + tracer = ColoTracer() + try: + graph = tracer.trace(root=model, meta_args=meta_args) + except Exception as e: + raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + # apply transform passes + annotated_model = balanced_split_pass(gm, 2) + split_model, split_submodules = split_with_split_nodes_pass(annotated_model) + + # get split model + model_part0 = list(split_model.children())[0] + model_part1 = list(split_model.children())[1] + + # set rng state and compute output of split model + torch.set_rng_state(cpu_rng_state) + output_part0 = model_part0(data) + sig = inspect.signature(model_part1.forward) + if isinstance(output_part0, torch.Tensor): + output_part1 = model_part1(output_part0) + else: + if len(output_part0) > len(sig.parameters): + output_part0 = output_part0[:len(sig.parameters)] + output_part1 = model_part1(*output_part0) + assert output.equal(output_part1) diff --git a/tests/test_fx/test_pipeline/test_topo/test_topo.py b/tests/test_fx/test_pipeline/test_topo/test_topo.py new file mode 100644 index 0000000000000000000000000000000000000000..16da56250dc3a43fffec4534119f6a1de586a0e7 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_topo/test_topo.py @@ -0,0 +1,49 @@ +import pytest +import torch +import transformers +from topo_utils import MLP, check_topo, split_model_and_get_DAG + +BATCH_SIZE = 1 +SEQ_LENGHT = 16 + + +@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +def test_opt(): + MODEL_LIST = [ + MLP, + transformers.OPTModel, + ] + + CONFIGS = [ + { + 'dim': 10, + 'layers': 12 + }, + transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4), + ] + + def data_gen_MLP(): + x = torch.zeros((16, 10)) + kwargs = dict(x=x) + return kwargs + + def data_gen_OPT(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + return kwargs + + DATAGEN = [ + data_gen_MLP, + data_gen_OPT, + ] + + for i, model_cls in enumerate(MODEL_LIST): + model = model_cls(config=CONFIGS[i]) + top_mod, topo = split_model_and_get_DAG(model, DATAGEN[i]) + # print(f'{top_mod=}\n----\n{topo=}') + check_topo(top_mod, topo) + + +if __name__ == '__main__': + test_opt() diff --git a/tests/test_fx/test_pipeline/test_topo/topo_utils.py b/tests/test_fx/test_pipeline/test_topo/topo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..55dd65201acd9ddd06d3becc84326ba97b1a7943 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_topo/topo_utils.py @@ -0,0 +1,92 @@ +import torch +from torch.fx import GraphModule +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass +from colossalai.fx import ColoTracer +from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo +from colossalai.pipeline.middleware.adaptor import get_fx_topology +import random +import numpy as np + +MANUAL_SEED = 0 +random.seed(MANUAL_SEED) +np.random.seed(MANUAL_SEED) +torch.manual_seed(MANUAL_SEED) + +class MLP(torch.nn.Module): + def __init__(self, config={}): + super().__init__() + dim = config['dim'] + layers = config['layers'] + self.layers = torch.nn.ModuleList() + + for _ in range(layers): + self.layers.append(torch.nn.Linear(dim, dim, bias=False)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + +def split_model_and_get_DAG(model, data_gen): + model.eval() + + # generate input sample + kwargs = data_gen() + + # tracing model + tracer = ColoTracer() + try: + meta_args = {k: v.to('meta') for k, v in kwargs.items()} + graph = tracer.trace(root=model, meta_args=meta_args) + except Exception as e: + raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + # apply transform passes + annotated_model = balanced_split_pass(gm, 2) + top_module, split_submodules = split_with_split_nodes_pass(annotated_model) + + topo = get_fx_topology(top_module) + for submodule in split_submodules: + if isinstance(submodule, torch.fx.GraphModule): + setattr(submodule, '_topo', topo) + + return top_module, split_submodules[0]._topo + +def check_input(top_module, input_partition: Partition): + partition_output = input_partition.get_output_vals() + arg_pos = 0 + for node in top_module.graph.nodes: + if node.op == 'placeholder': + cur_checkee = partition_output[arg_pos] + to_partition_and_offset = cur_checkee.get() + assert len(to_partition_and_offset) == len(node.users.keys()) + arg_pos += 1 + + assert arg_pos == len(partition_output) + +def check_submod(top_module, part_id, mid_partition: Partition): + partition_input = mid_partition.get_input_vals() + partition_output = mid_partition.get_output_vals() + + cnt = 1 + cur_node = None + for node in top_module.graph.nodes: + if node.name.startswith('submod'): + cnt += 1 + if cnt == part_id: + cur_node = node + break + + assert len(partition_input) == len(cur_node.args) + assert len(partition_output) == len(cur_node.users) + +def check_topo(top_module, topo: Topo): + input_partition = topo.get_input_partition() + mid_partitions = topo.get_mid_partitions() + + check_input(top_module, input_partition) + for part_id, submod in mid_partitions.items(): + check_submod(top_module, part_id, submod) + \ No newline at end of file diff --git a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py new file mode 100644 index 0000000000000000000000000000000000000000..5d47be2c7bea01323cf1938cfa98ffe1175271de --- /dev/null +++ b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py @@ -0,0 +1,66 @@ +import inspect +import random + +import numpy as np +import pytest +import torch +import torchvision +import torchvision.models as tm +from packaging import version +from torch.fx import GraphModule + +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass + +MANUAL_SEED = 0 +random.seed(MANUAL_SEED) +np.random.seed(MANUAL_SEED) +torch.manual_seed(MANUAL_SEED) +torch.backends.cudnn.deterministic = True + + +@pytest.mark.skip('balance split v2 is not ready') +def test_torchvision_models(): + MODEL_LIST = [ + tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, + tm.regnet_x_16gf, tm.efficientnet_b0, tm.mnasnet0_5 + ] + + if version.parse(torchvision.__version__) >= version.parse('0.12.0'): + MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small]) + + tracer = ColoTracer() + data = torch.rand(2, 3, 224, 224) + + for model_cls in MODEL_LIST: + model = model_cls() + model.eval() + cpu_rng_state = torch.get_rng_state() + output = model(data) + graph = tracer.trace(root=model) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + # apply transform passes + annotated_model = balanced_split_pass(gm, 2) + split_model, split_submodules = split_with_split_nodes_pass(annotated_model) + + # get split model + model_part0 = list(split_model.children())[0] + model_part1 = list(split_model.children())[1] + + # set rng state and compute output of split model + torch.set_rng_state(cpu_rng_state) + output_part0 = model_part0(data) + sig = inspect.signature(model_part1.forward) + if isinstance(output_part0, torch.Tensor): + output_part1 = model_part1(output_part0) + else: + if len(output_part0) > len(sig.parameters): + output_part0 = output_part0[:len(sig.parameters)] + output_part1 = model_part1(*output_part0) + assert output.equal(output_part1) + + +if __name__ == '__main__': + test_torchvision_models() diff --git a/tests/test_fx/test_pipeline_passes.py b/tests/test_fx/test_pipeline_passes.py new file mode 100644 index 0000000000000000000000000000000000000000..1078dac9db7cb4c2e00191a6fc33121f397f89e6 --- /dev/null +++ b/tests/test_fx/test_pipeline_passes.py @@ -0,0 +1,57 @@ +import pytest +import torch +import torch.nn as nn +from torch.fx import symbolic_trace + +import colossalai +import colossalai.nn as col_nn +from colossalai.fx.passes.adding_split_node_pass import ( + balanced_split_pass, + balanced_split_pass_v2, + split_with_split_nodes_pass, + uniform_split_pass, +) +from colossalai.testing import clear_cache_before_run + +MODEL_DIM = 16 +BATCH_SIZE = 8 +PIPELINE_SIZE = 2 + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.linear3 = torch.nn.Linear(dim, dim) + self.linear4 = torch.nn.Linear(dim, dim) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + return x + + +def pipeline_pass_test_helper(model, data, pass_func): + origin_output = model(data) + symbolic_traced = symbolic_trace(model) + annotated_model = pass_func(symbolic_traced, PIPELINE_SIZE) + split_model, split_submodules = split_with_split_nodes_pass(annotated_model) + output = split_model(data) + assert output.equal(origin_output) + + +@clear_cache_before_run() +def test_pipeline_passes(): + model = MLP(MODEL_DIM) + data = torch.rand(BATCH_SIZE, MODEL_DIM) + pipeline_pass_test_helper(model, data, balanced_split_pass) + pipeline_pass_test_helper(model, data, balanced_split_pass_v2) + pipeline_pass_test_helper(model, data, uniform_split_pass) + + +if __name__ == '__main__': + test_pipeline_passes() diff --git a/tests/test_fx/test_profiler/gpt_utils.py b/tests/test_fx/test_profiler/gpt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aec32268484fbcc3b362207a02b8a022bf660642 --- /dev/null +++ b/tests/test_fx/test_profiler/gpt_utils.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +from transformers import GPT2Config, GPT2LMHeadModel + + +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +def gpt2_medium(checkpoint=False): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_xl(checkpoint=False): + return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint) diff --git a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..b5a6bbe8bf181c92f5fb8bf18779925889bc35fb --- /dev/null +++ b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py @@ -0,0 +1,184 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.fx +import torchvision.models as tm +from gpt_utils import gpt2_medium, gpt2_xl +from torch.fx import symbolic_trace + +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.testing import clear_cache_before_run, run_on_environment_flag + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +TM_BATCH_SIZE = 64 +GPT_BATCH_SIZE = 8 +NUM_STEPS = 5 + + +def extract_forward_mem(gm: torch.fx.GraphModule): + node_size = 0 + param_size = 0 + for node in gm.graph.nodes: + node_size += calculate_fwd_tmp(node) + node_size += calculate_fwd_out(node) + param_size = parameter_size(gm) + return (node_size + param_size) / 1024**2, param_size / 1024**2 + + +def extract_forward_flops(gm: torch.fx.GraphModule): + fwd_flop = 0 + bwd_flop = 0 + for node in gm.graph.nodes: + fwd_flop += node.meta.get('fwd_flop', 0) + bwd_flop += node.meta.get('bwd_flop', 0) + return fwd_flop, bwd_flop + + +def gen_tm_data(batch_size: int, shape: Tuple[int, int, int], device='cuda'): + data = torch.rand(batch_size, *shape, device=device) + label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000) + return data, label + + +def gen_gpt_data(batch_size, seq_len, vocab_size, device='cpu'): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + attention_mask = torch.ones_like(input_ids, device=device) + return input_ids, attention_mask + + +def run_tm_forward(gm: torch.fx.GraphModule): + torch.cuda.reset_peak_memory_stats() + forward_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2 + param_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2 + gm.cuda() + param_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 + gm.train() + for n in range(NUM_STEPS): + torch.cuda.reset_peak_memory_stats() + data, _ = gen_tm_data(TM_BATCH_SIZE, (3, 224, 224)) + + # If we need to dive deep into the memory usage by + # inspecting `saved_tensor_hooks` + + # ===================================================== + # fwd_mem = 0 + # cache = set() + # def pack(x): + # if isinstance(x, torch.Tensor): + # nonlocal fwd_mem, cache + # if x.data_ptr() not in cache: + # fwd_mem += activation_size(x) + # cache.add(x.data_ptr()) + # return x + # def unpack(x): + # return x + # + # with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + # output = gm(data) + # print(f'Memory estimation by saved_tensor_hooks: {fwd_mem / 1024**2}') + # ===================================================== + + output = gm(data) + forward_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 / NUM_STEPS + del output + return forward_mem, param_mem + + +def run_gpt_forward(gm: torch.fx.GraphModule): + torch.cuda.reset_peak_memory_stats() + forward_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2 + param_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2 + gm.cuda() + param_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 + for n in range(NUM_STEPS): + torch.cuda.reset_peak_memory_stats() + data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='cuda:0') + + # If we need to dive deep into the memory usage by + # inspecting `saved_tensor_hooks` + + # ===================================================== + # fwd_mem = 0 + # cache = set() + # def pack(x): + # if isinstance(x, torch.Tensor): + # nonlocal fwd_mem, cache + # if x.data_ptr() not in cache: + # fwd_mem += activation_size(x) + # cache.add(x.data_ptr()) + # return x + # def unpack(x): + # return x + # + # with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + # output = gm(data, mask) + # print(f'Memory estimation by saved_tensor_hooks: {fwd_mem / 1024**2}') + # ===================================================== + + output = gm(data, mask) + forward_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 / NUM_STEPS + del output + return forward_mem, param_mem + + +@run_on_environment_flag(name='FX_PROFILER') +@clear_cache_before_run() +def test_meta_info_prop(): + for m in [ + tm.alexnet, tm.resnet18, tm.resnet34, tm.resnet50, tm.resnet101, tm.resnet152, tm.densenet121, + tm.densenet161, tm.densenet169, tm.densenet201, tm.convnext_tiny, tm.convnext_small, tm.convnext_base, + tm.convnext_large, tm.wide_resnet50_2, tm.wide_resnet101_2, tm.regnet_x_16gf, tm.mnasnet0_5, + tm.efficientnet_b0, tm.shufflenet_v2_x0_5, tm.shufflenet_v2_x1_0, tm.shufflenet_v2_x1_5, + tm.shufflenet_v2_x2_0, tm.mobilenet_v2, tm.mobilenet_v3_small, tm.mobilenet_v3_large, tm.resnext50_32x4d, + tm.resnext101_32x8d, tm.resnext101_64x4d, tm.vit_b_16, tm.vit_b_32, tm.vit_h_14, tm.vit_l_16, tm.vit_l_32, + tm.vgg11, tm.vgg11_bn, tm.vgg13, tm.vgg13_bn, tm.vgg16, tm.vgg16_bn, tm.vgg19, tm.vgg19_bn + ]: + model = m().cuda() + model.train() + data = MetaTensor(torch.rand(int(TM_BATCH_SIZE), 3, 224, 224, device='meta'), fake_device='cuda:0') + gm = symbolic_trace(model) + interp = MetaInfoProp(gm) + interp.propagate(data) + gm.cpu() + + meta_forward_mem, meta_param_mem = extract_forward_mem(gm) + fwd_flop, bwd_flop = extract_forward_flops(gm) + concrete_forward_mem, concrete_param_mem = run_tm_forward(gm) + + print( + f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|' + ) + del model, gm + + +@run_on_environment_flag(name='FX_PROFILER') +@clear_cache_before_run() +def test_gpt_meta_info_prop(): + for m in [gpt2_medium]: + model = m().cuda() + model.train() + data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='meta') + graph = ColoTracer().trace(model, meta_args={'input_ids': data, 'attention_mask': mask}) + gm = torch.fx.GraphModule(model, graph) + interp = MetaInfoProp(gm) + interp.propagate(MetaTensor(data, fake_device='cuda:0'), MetaTensor(mask, fake_device='cuda:0')) + model.cpu() + + fwd_flop, bwd_flop = extract_forward_flops(gm) + + concrete_forward_mem, concrete_param_mem = run_gpt_forward(gm) + meta_forward_mem, meta_param_mem = extract_forward_mem(gm) + + print( + f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|' + ) + del model, gm + + +if __name__ == '__main__': + test_meta_info_prop() + test_gpt_meta_info_prop() diff --git a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py new file mode 100644 index 0000000000000000000000000000000000000000..632ab8c097505d6e805219527c7d51ecc6ba69d6 --- /dev/null +++ b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +from torch.fx import GraphModule +from torch.utils.checkpoint import checkpoint + +from colossalai.fx import ColoTracer +from colossalai.testing import clear_cache_before_run + + +class MLP(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +# Simple module for demonstration +class MyModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.mlp_1 = MLP() + self.mlp_2 = MLP() + self.output = torch.nn.Linear(4, 4) + + def forward(self, x): + x = checkpoint(self.mlp_1, x) + x = checkpoint(self.mlp_2, x) + x = self.output(x) + return x + + +@clear_cache_before_run() +def test_activation_checkpoint_annotation(): + module = MyModule() + + # test tracing with activation checkpoint + tracer = ColoTracer(trace_act_ckpt=True) + graph = tracer.trace(module) + gm = GraphModule(module, graph) + + for node in gm.graph.nodes: + if node.name in ['mlp_1_linear1', 'mlp_1_linear2']: + assert node.meta.get('activation_checkpoint', -1) == 0 + + for node in gm.graph.nodes: + if node.name in ['mlp_2_linear1', 'mlp_2_linear2']: + assert node.meta.get('activation_checkpoint', -1) == 1 + + tracer = ColoTracer(trace_act_ckpt=False) + graph = tracer.trace(module) + gm = GraphModule(module, graph) + + for node in gm.graph.nodes: + assert not hasattr(node, 'activation_checkpoint') + + +if __name__ == '__main__': + test_activation_checkpoint_annotation() diff --git a/tests/test_fx/test_tracer/test_bias_addition_module.py b/tests/test_fx/test_tracer/test_bias_addition_module.py new file mode 100644 index 0000000000000000000000000000000000000000..2f88d8c784e8144e557747bc35e86b74c96e7478 --- /dev/null +++ b/tests/test_fx/test_tracer/test_bias_addition_module.py @@ -0,0 +1,117 @@ +import torch + +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing import clear_cache_before_run + + +class LinearModel(torch.nn.Module): + + def __init__(self, in_features, out_features): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x): + x = self.linear(x) + x = x * 2 + + return x + + +class ConvModel(torch.nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, bias=True): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias) + + def forward(self, x): + x = self.conv(x) + x = x * 2 + + return x + + +@clear_cache_before_run() +def test_linear_module(): + model = LinearModel(3, 6) + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %linear_weight : [#users=1] = get_attr[target=linear.weight] + # %linear_bias : [#users=1] = get_attr[target=linear.bias] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) + # return mul + graph = tracer.trace(root=model, meta_args={'x': torch.rand(3, 3).to('meta')}) + # def forward(self, x : torch.Tensor): + # linear_weight = self.linear.weight + # linear_bias = self.linear.bias + # linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None + # add = linear + linear_bias; linear = linear_bias = None + # mul = add * 2; add = None + # return mul + gm = ColoGraphModule(model, graph) + gm.recompile() + node_list = list(graph.nodes) + for node in node_list: + if node.op == 'output': + continue + assert hasattr(node, '_meta_data') + weight_node = node_list[1] + bias_node = node_list[2] + linear_node = node_list[3] + add_node = node_list[4] + assert weight_node._meta_data.shape == (6, 3) + assert bias_node._meta_data.shape == (6,) + assert linear_node._meta_data.shape == (3, 6) + assert add_node._meta_data.shape == (3, 6) + + +@clear_cache_before_run() +def test_conv_module(): + model = ConvModel(3, 6, 2) + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # %conv_bias : [#users=1] = get_attr[target=conv.bias] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {}) + # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) + # return mul + graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')}) + # def forward(self, x : torch.Tensor): + # conv_weight = self.conv.weight + # conv_bias = self.conv.bias + # conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None + # view = conv_bias.view([1, -1, 1, 1]); conv_bias = None + # add = conv2d + view; conv2d = view = None + # mul = add * 2; add = None + # return mul + gm = ColoGraphModule(model, graph) + + gm.recompile() + node_list = list(graph.nodes) + for node in node_list: + if node.op == 'output': + continue + assert hasattr(node, '_meta_data') + weight_node = node_list[1] + bias_node = node_list[2] + conv_node = node_list[3] + view_node = node_list[4] + add_node = node_list[5] + assert weight_node._meta_data.shape == (6, 3, 2, 2) + assert bias_node._meta_data.shape == (6,) + assert conv_node._meta_data.shape == (4, 6, 63, 63) + assert view_node._meta_data.shape == (6, 1, 1) + assert add_node._meta_data.shape == (4, 6, 63, 63) + + +if __name__ == '__main__': + test_linear_module() + test_conv_module() diff --git a/tests/test_fx/test_tracer/test_control_flow.py b/tests/test_fx/test_tracer/test_control_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..820729dadb3ed21aad6f0230e3094f839a14d50f --- /dev/null +++ b/tests/test_fx/test_tracer/test_control_flow.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +from torch.fx import GraphModule + +from colossalai.fx import ColoTracer as Tracer +from colossalai.testing import clear_cache_before_run + + +class ControlFlowModel(nn.Module): + + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 10) + self.linear2 = nn.Linear(10, 10) + + def forward(self, x, y): + x1 = self.linear1(x) + y1 = self.linear2(y) + + if x1.dim() == 2: + return x1 + y1 + else: + return x1 - y1 + + +@clear_cache_before_run() +def test_control_flow(): + model = ControlFlowModel() + tracer = Tracer() + graph_branch_true = tracer.trace(model, + meta_args={ + 'x': torch.rand(4, 10, device='meta'), + 'y': torch.rand(4, 10, device='meta') + }) + graph_branch_false = tracer.trace(model, + meta_args={ + 'x': torch.rand(10, device='meta'), + 'y': torch.rand(4, 10, device='meta') + }) + + gm_branch_true = GraphModule(model, graph_branch_true, model.__class__.__name__) + gm_branch_false = GraphModule(model, graph_branch_false, model.__class__.__name__) + gm_branch_true.recompile() + gm_branch_false.recompile() + + # test the true branch + x = torch.rand(4, 10) + y = torch.rand(4, 10) + assert torch.all(model(x, y) == gm_branch_true(x, y)) + assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y)) + + # test the true branch + x = torch.rand(10) + y = torch.rand(4, 10) + assert torch.all(model(x, y) == gm_branch_false(x, y)) + assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y)) + + +if __name__ == '__main__': + test_control_flow() diff --git a/tests/test_fx/test_tracer/test_functional_conv.py b/tests/test_fx/test_tracer/test_functional_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..a552e905223d5ce2b1c9e7208242edb565b7e882 --- /dev/null +++ b/tests/test_fx/test_tracer/test_functional_conv.py @@ -0,0 +1,51 @@ +import torch +from torch.nn import functional as F + +from colossalai.fx.tracer.meta_patch import patched_function +from colossalai.testing import clear_cache_before_run + + +@clear_cache_before_run() +def test_conv(): + # test F.conv_1d + data_1d = torch.rand(3, 16, 10) + weight_1d = torch.rand(3, 16, 3) + out_1d = F.conv1d(data_1d, weight_1d) + patched_out_1d = patched_function.torch_nn_functional_conv1d(data_1d, weight_1d) + assert out_1d.shape == patched_out_1d.shape + + # test F.conv_transpose1d + weight_1d = torch.transpose(weight_1d, 0, 1) + out_transpose_1d = F.conv_transpose1d(data_1d, weight_1d) + patched_out_transpose_1d = patched_function.torch_nn_functional_convtranspose1d(data_1d, weight_1d) + assert out_transpose_1d.shape == patched_out_transpose_1d.shape + + # test F.conv2d + data_2d = torch.rand(3, 16, 10, 10) + weight_2d = torch.rand(3, 16, 3, 3) + out_2d = F.conv2d(data_2d, weight_2d) + patched_out_2d = patched_function.torch_nn_functional_conv2d(data_2d, weight_2d) + assert out_2d.shape == patched_out_2d.shape + + # test F.conv_transpose2d + weight_2d = torch.transpose(weight_2d, 0, 1) + out_transpose_2d = F.conv_transpose2d(data_2d, weight_2d) + patched_out_transpose_2d = patched_function.torch_nn_functional_convtranspose2d(data_2d, weight_2d) + assert out_transpose_2d.shape == patched_out_transpose_2d.shape + + # test F.conv3d + data_3d = torch.rand(3, 16, 10, 10, 10) + weight_3d = torch.rand(3, 16, 3, 3, 3) + out_3d = F.conv3d(data_3d, weight_3d) + patched_out_3d = patched_function.torch_nn_functional_conv3d(data_3d, weight_3d) + assert out_3d.shape == patched_out_3d.shape + + # test F.conv_transpose3d + weight_3d = torch.transpose(weight_3d, 0, 1) + out_transpose_3d = F.conv_transpose3d(data_3d, weight_3d) + patched_out_transpose_3d = patched_function.torch_nn_functional_convtranspose3d(data_3d, weight_3d) + assert out_transpose_3d.shape == patched_out_transpose_3d.shape + + +if __name__ == '__main__': + test_conv() diff --git a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7a4bf131ae36875aa9ce69ba00a08e75741bf1d2 --- /dev/null +++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py @@ -0,0 +1,31 @@ +import torch +from numpy import isin +from torch.fx import GraphModule +from torch.utils._pytree import tree_flatten + +# from colossalai.fx import symbolic_trace +from colossalai._analyzer.fx import symbolic_trace + + +def trace_model_and_compare_output(model, data_gen): + # must turn on eval mode to ensure the output is consistent + model.eval() + + try: + kwargs = data_gen() + meta_args = {k: v.to('meta') for k, v in kwargs.items()} + gm = symbolic_trace(model, meta_args=meta_args) + except Exception as e: + raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") + + # run forward + inputs = data_gen() + non_fx_out = model(**inputs) + fx_out = gm(**inputs) + + # check output + for k in non_fx_out.keys(): + if torch.is_tensor(fx_out[k]): + assert torch.equal( + fx_out[k], non_fx_out[k] + ), f'{model.__class__.__name__} has incorrect output {k}, expect {non_fx_out[k]}, but got {fx_out[k]}' diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py new file mode 100644 index 0000000000000000000000000000000000000000..f4d681221191ca2084f184ec0bacea0749356c74 --- /dev/null +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py @@ -0,0 +1,24 @@ +import pytest +import torch +from hf_tracer_utils import trace_model_and_compare_output +from packaging import version + +from colossalai.testing import clear_cache_before_run +from tests.kit.model_zoo import model_zoo + +BATCH_SIZE = 2 +SEQ_LENGTH = 16 + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() +def test_albert(): + sub_registry = model_zoo.get_sub_registry('transformers_albert') + + for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + model = model_fn() + trace_model_and_compare_output(model, data_gen_fn) + + +if __name__ == '__main__': + test_albert() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..a833bb30c056b772dce2b4cbb158be37d254ed0e --- /dev/null +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -0,0 +1,21 @@ +import pytest +import torch +from hf_tracer_utils import trace_model_and_compare_output +from packaging import version + +from colossalai.testing import clear_cache_before_run +from tests.kit.model_zoo import model_zoo + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() +def test_bert(): + sub_registry = model_zoo.get_sub_registry('transformers_bert') + + for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + model = model_fn() + trace_model_and_compare_output(model, data_gen_fn) + + +if __name__ == '__main__': + test_bert() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py new file mode 100644 index 0000000000000000000000000000000000000000..0cbea82e083a80159bd4603b6995e98ca86dc20a --- /dev/null +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -0,0 +1,72 @@ +import pytest +import torch + +from colossalai.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run +from colossalai.testing.random import seed_all +from tests.kit.model_zoo import model_zoo + + +def assert_dict(da, db, assert_fn): + assert len(da) == len(db) + for k, v in da.items(): + assert k in db + if not torch.is_tensor(v): + continue + u = db.get(k) + assert_fn(u, v) + + +def trace_and_compare(model_cls, data, output_fn): + model = model_cls() + model.eval() + + concrete_args = {k: v for k, v in data.items() if not torch.is_tensor(v)} + meta_args = {k: v.to('meta') for k, v in data.items() if torch.is_tensor(v)} + gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args) + + # run forward + with torch.no_grad(): + fx_out = gm(**data) + non_fx_out = model(**data) + + # compare output + transformed_fx_out = output_fn(fx_out) + transformed_non_fx_out = output_fn(non_fx_out) + + def assert_fn(ta, tb): + assert torch.equal(ta, tb) + + assert_dict(transformed_fx_out, transformed_non_fx_out, assert_fn) + + +@pytest.mark.skip(reason='cannot pass this test yet') +@clear_cache_before_run() +def test_diffusers(): + seed_all(9091, cuda_deterministic=True) + + sub_model_zoo = model_zoo.get_sub_registry('diffusers') + + for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + data = data_gen_fn() + trace_and_compare(model_fn, data, output_transform_fn) + torch.cuda.synchronize() + print(f"{name:40s} √") + + +@clear_cache_before_run() +def test_torch_diffusers(): + seed_all(65535, cuda_deterministic=True) + + sub_model_zoo = model_zoo.get_sub_registry('diffusers') + + for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + data = data_gen_fn() + model = model_fn() + output = model(**data) + torch.cuda.synchronize() + print(f"{name:40s} √") + + +if __name__ == "__main__": + test_torch_diffusers() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..67107469d8bb898edfb487cc63bdac0cbfe9b024 --- /dev/null +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -0,0 +1,28 @@ +import pytest +import torch +from hf_tracer_utils import trace_model_and_compare_output +from packaging import version + +from colossalai.testing import clear_cache_before_run +from tests.kit.model_zoo import model_zoo + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() +def test_gpt(): + sub_registry = model_zoo.get_sub_registry('transformers_gpt') + + for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + model = model_fn() + + # TODO: support the following models + # 1. GPT2DoubleHeadsModel + # as they are not supported, let's skip them + if model.__class__.__name__ in ['GPT2DoubleHeadsModel']: + continue + + trace_model_and_compare_output(model, data_gen_fn) + + +if __name__ == '__main__': + test_gpt() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..369545b03de1b4bf0746d46b9928ed68be9688d4 --- /dev/null +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -0,0 +1,21 @@ +import pytest +import torch +from hf_tracer_utils import trace_model_and_compare_output +from packaging import version + +from colossalai.testing import clear_cache_before_run +from tests.kit.model_zoo import model_zoo + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() +def test_opt(): + sub_registry = model_zoo.get_sub_registry('transformers_opt') + + for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + model = model_fn() + trace_model_and_compare_output(model, data_gen_fn) + + +if __name__ == '__main__': + test_opt() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..811cf3b214303ae62569158d1321b1a99ff6d76e --- /dev/null +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -0,0 +1,21 @@ +import pytest +import torch +from hf_tracer_utils import trace_model_and_compare_output +from packaging import version + +from colossalai.testing import clear_cache_before_run +from tests.kit.model_zoo import model_zoo + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() +def test_t5(): + sub_registry = model_zoo.get_sub_registry('transformers_t5') + + for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + model = model_fn() + trace_model_and_compare_output(model, data_gen_fn) + + +if __name__ == '__main__': + test_t5() diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py new file mode 100644 index 0000000000000000000000000000000000000000..ef778e21801a84bd4e55ca08937db67bffaca5c1 --- /dev/null +++ b/tests/test_fx/test_tracer/test_patched_module.py @@ -0,0 +1,498 @@ +import torch + +from colossalai.fx.tracer.meta_patch import patched_module +from colossalai.testing import clear_cache_before_run + + +def _run(data, module, patch_fn): + try: + if isinstance(data, dict): + output = patch_fn(module, **data) + if isinstance(data, tuple) or isinstance(data, list): + output = patch_fn(module, *data) + else: + output = patch_fn(module, data) + return output + except Exception as e: + return e + + +def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape): + output = _run(data, module, patch_fn) + + if expect_exception: + assert isinstance(output, AssertionError) + else: + assert not isinstance(output, Exception) + if isinstance(output, tuple): + for item, shape in zip(output, output_shape): + assert item.is_meta + assert item.shape == shape + else: + assert output.is_meta + assert output.shape == output_shape + + +@clear_cache_before_run() +def test_linear(): + # test linear patch can produce the meta output with correct shape + data = torch.rand(2, 4, device='meta') + module = torch.nn.Linear(4, 2) + _assert_output_shape(data, module, patched_module.torch_nn_linear, False, torch.Size([2, 2])) + + # test if the linear patch can catch exception when dimension does not match + data = torch.rand(2, 2, device='meta') + _assert_output_shape(data, module, patched_module.torch_nn_linear, True, None) + + +@clear_cache_before_run() +def test_rnn(): + # test rnn patch can produce the meta output with correct shape + data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20)) + module = torch.nn.RNN(10, 20, 2) + output, hn = module(*data) + meta_data = (torch.randn(5, 3, 10).to('meta'), torch.randn(2, 3, 20).to('meta')) + _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, False, (output.shape, hn.shape)) + + # test if the rnn patch can catch exception when dimension does not match + data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20)) + module = torch.nn.RNN(10, 20, 2) + output, hn = module(*data) + meta_data = (torch.randn(5, 3, 1).to('meta'), torch.randn(2, 3, 20).to('meta')) + _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, True, None) + + +@clear_cache_before_run() +def test_embedding(): + data = torch.rand(2, 4, device='meta') + + # test layernorm + ln = torch.nn.LayerNorm(4) + _assert_output_shape(data, ln, patched_module.torch_nn_normalize, False, data.shape) + + # test group norm + gn = torch.nn.GroupNorm(4, num_channels=8) + _assert_output_shape(data, gn, patched_module.torch_nn_normalize, False, data.shape) + + # test batch norm 1d + bn1d = torch.nn.BatchNorm1d(4) + data = torch.rand(2, 4, device='meta') + _assert_output_shape(data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape) + + data = torch.rand(2, 4, device='meta') + _assert_output_shape(data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape) + + data = torch.rand(2, 3, 4, device='meta') + _assert_output_shape(data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape) + + data = torch.rand(1, 2, 3, 4, device='meta') + _assert_output_shape(data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=True, + output_shape=None) + + # test batch norm 2d + bn2d = torch.nn.BatchNorm2d(4) + + data = torch.rand(1, 2, 3, 4, device='meta') + _assert_output_shape(data=data, + module=bn2d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape) + + data = torch.rand(2, 3, 4, device='meta') + _assert_output_shape(data=data, + module=bn2d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=True, + output_shape=None) + + # # test batch size 3d + bn3d = torch.nn.BatchNorm3d(4) + + data = torch.rand(1, 1, 2, 3, 4, device='meta') + _assert_output_shape(data=data, + module=bn3d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape) + + data = torch.rand(1, 2, 3, 4, device='meta') + _assert_output_shape(data=data, + module=bn3d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=True, + output_shape=None) + + +@clear_cache_before_run() +def test_conv1d(): + # test conv 1d + data = torch.rand(2, 3, 4) + + conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = conv1d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = conv1d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv1d = torch.nn.Conv1d(in_channels=3, + out_channels=4, + kernel_size=2, + padding=1, + dilation=2, + padding_mode='reflect') + materialized_output = conv1d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape) + + +def test_conv2d(): + # test conv 2d + data = torch.rand(2, 3, 4, 4) + conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = conv2d(data) + _assert_output_shape(data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = conv2d(data) + _assert_output_shape(data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2) + materialized_output = conv2d(data) + _assert_output_shape(data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv2d = torch.nn.Conv2d(in_channels=3, + out_channels=4, + kernel_size=2, + padding=1, + dilation=2, + padding_mode='reflect') + materialized_output = conv2d(data) + _assert_output_shape(data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape) + + +@clear_cache_before_run() +def test_conv3d(): + # test conv 3d + data = torch.rand(2, 3, 4, 4, 4) + conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = conv3d(data) + _assert_output_shape(data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = conv3d(data) + _assert_output_shape(data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2) + materialized_output = conv3d(data) + _assert_output_shape(data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape) + + conv3d = torch.nn.Conv3d(in_channels=3, + out_channels=4, + kernel_size=2, + padding=1, + dilation=2, + padding_mode='reflect') + materialized_output = conv3d(data) + _assert_output_shape(data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape) + + +@clear_cache_before_run() +def test_conv_transpose1d(): + # test conv transpose1d + data = torch.rand(2, 3, 4) + + convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = convtrans1d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans1d, + patch_fn=patched_module.torch_nn_convtranspose1d, + expect_exception=False, + output_shape=materialized_output.shape) + + convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = convtrans1d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans1d, + patch_fn=patched_module.torch_nn_convtranspose1d, + expect_exception=False, + output_shape=materialized_output.shape) + + +@clear_cache_before_run() +def test_conv_transpose2d(): + # test conv transpose2d + data = torch.rand(2, 3, 4, 4) + + convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = convtrans2d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans2d, + patch_fn=patched_module.torch_nn_convtranspose2d, + expect_exception=False, + output_shape=materialized_output.shape) + + convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = convtrans2d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans2d, + patch_fn=patched_module.torch_nn_convtranspose2d, + expect_exception=False, + output_shape=materialized_output.shape) + + +@clear_cache_before_run() +def test_conv_transpose3d(): + # test conv transpose2d + data = torch.rand(2, 3, 4, 4, 4) + + convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = convtrans3d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans3d, + patch_fn=patched_module.torch_nn_convtranspose3d, + expect_exception=False, + output_shape=materialized_output.shape) + + convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = convtrans3d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans3d, + patch_fn=patched_module.torch_nn_convtranspose3d, + expect_exception=False, + output_shape=materialized_output.shape) + + +@clear_cache_before_run() +def test_pool1d(): + combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d], + [torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]] + + for (layer_cls, patch_func) in combinations: + pooler = layer_cls(kernel_size=3) + + data = torch.rand(2, 3, 4) + materialized_output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape) + + data = torch.rand(2, 4) + materialized_output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape) + + data = torch.rand(2, 3, 4, 4) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + +@clear_cache_before_run() +def test_pool2d(): + combinations = [[torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d], + [torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d]] + + for (layer_cls, patch_func) in combinations: + pooler = layer_cls(kernel_size=3) + + # test max pool 3d + data = torch.rand(2, 3, 4, 4) + materialized_output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape) + + # test max pool 3d + data = torch.rand(2, 4, 4) + materialized_output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape) + + # test max pool 3d + data = torch.rand(2, 3, 4, 4, 4) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + +@clear_cache_before_run() +def test_pool3d(): + combinations = [[torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d], + [torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d]] + + for (layer_cls, patch_func) in combinations: + pooler = layer_cls(kernel_size=3) + + # test max pool 3d + data = torch.rand(2, 3, 4, 4, 4) + materialized_output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape) + + # test max pool 3d + data = torch.rand(2, 4, 4, 4) + materialized_output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape) + + # test max pool 3d + data = torch.rand(2, 3, 4) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + +# adapative pooling is different from other pooling, so test it individually +@clear_cache_before_run() +def test_adaptive_pooling_1d(): + pooler = torch.nn.AdaptiveAvgPool1d(output_size=3) + patch_func = patched_module.torch_nn_adapative_pooling_1d + + data = torch.rand(3, 4) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + data = torch.rand(2, 3, 4) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + data = torch.rand(2, 3, 4, 5) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + +@clear_cache_before_run() +def test_adaptive_pooling_2d(): + pooler = torch.nn.AdaptiveAvgPool2d(output_size=3) + patch_func = patched_module.torch_nn_adapative_pooling_2d + + data = torch.rand(3, 4) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + data = torch.rand(2, 3, 4) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + data = torch.rand(2, 3, 4, 5) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + +@clear_cache_before_run() +def test_adaptive_pooling_3d(): + pooler = torch.nn.AdaptiveAvgPool3d(output_size=3) + patch_func = patched_module.torch_nn_adapative_pooling_3d + + data = torch.rand(3, 4, 5) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + data = torch.rand(2, 3, 4, 5) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + data = torch.rand(2, 3, 4, 5, 6) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) diff --git a/tests/test_fx/test_tracer/test_patched_op.py b/tests/test_fx/test_tracer/test_patched_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e0c5f560c49e0d4d0807fb0b42e9ede960704b7d --- /dev/null +++ b/tests/test_fx/test_tracer/test_patched_op.py @@ -0,0 +1,87 @@ +from functools import partial + +import torch + +from colossalai.fx.tracer.meta_patch import patched_function +from colossalai.testing import clear_cache_before_run + + +def _run(data, patch_fn): + try: + output = patch_fn(data) + return output + except Exception as e: + return e + + +def _assert_output_shape(data, patch_fn, expect_exception, output_shape): + output = _run(data, patch_fn) + + if expect_exception: + assert isinstance(output, AssertionError) + else: + assert not isinstance(output, Exception) + assert output.is_meta + assert output.shape == output_shape + + +@clear_cache_before_run() +def test_repeat_interleave(): + patch_fn = patched_function.torch_repeat_interleave + + # examples from https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html + data = torch.tensor([1, 2, 3]) + materialized_output = torch.repeat_interleave(data, repeats=2) + repeat_interleave = partial(patch_fn, repeats=2) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + patch_fn=repeat_interleave, + expect_exception=False, + output_shape=materialized_output.shape) + + data = torch.tensor([[1, 2], [3, 4]]) + materialized_output = torch.repeat_interleave(data, repeats=3, dim=1) + repeat_interleave = partial(patch_fn, repeats=3, dim=1) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + patch_fn=repeat_interleave, + expect_exception=False, + output_shape=materialized_output.shape) + + data = torch.tensor([[1, 2], [3, 4]]) + materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=-1) + repeat_interleave = partial(patch_fn, repeats=torch.tensor([1, 2]), dim=-1) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + patch_fn=repeat_interleave, + expect_exception=False, + output_shape=materialized_output.shape) + + data = torch.tensor([[1, 2], [3, 4]]) + materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=0) + repeat_interleave = partial(patch_fn, repeats=[1, 2], dim=0) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + patch_fn=repeat_interleave, + expect_exception=True, + output_shape=materialized_output.shape) + + +@clear_cache_before_run() +def test_torch_max(): + data = torch.rand(4, 3) + out = torch.max(data) + patched_out = patched_function.torch_max(data) + assert out.shape == patched_out.shape + + data = torch.rand(4, 3, 2) + out, idx = torch.max(data, dim=1) + patched_out, patched_idx = patched_function.torch_max(data, dim=1) + assert out.shape == patched_out.shape + assert idx.shape == patched_idx.shape + + data = torch.rand(4, 3, 2) + out, idx = torch.max(data, dim=1, keepdim=True) + patched_out, patched_idx = patched_function.torch_max(data, dim=1, keepdim=True) + assert out.shape == patched_out.shape + assert idx.shape == patched_idx.shape diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..aa14f514c7d6d86bb8c384200777d685fa0df93a --- /dev/null +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -0,0 +1,64 @@ +import pytest +import torch +from packaging import version + +from colossalai._analyzer.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run +from tests.kit.model_zoo import model_zoo + + +def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): + # trace + model = model_cls() + + # convert to eval for inference + # it is important to set it to eval mode before tracing + # without this statement, the torch.nn.functional.batch_norm will always be in training mode + model.eval() + + # TODO: support the following models + # 1. ConViT + # 2. NormFreeNet + # as they are not supported, let's skip them + if model.__class__.__name__ in ['ConViT', 'NormFreeNet']: + return + + gm = symbolic_trace(model, meta_args=meta_args) + + # run forward + with torch.no_grad(): + fx_out = gm(**data) + non_fx_out = model(**data) + + # compare output + transformed_fx_out = output_transform_fn(fx_out) + transformed_non_fx_out = output_transform_fn(non_fx_out) + + assert len(transformed_fx_out) == len(transformed_non_fx_out) + + for key in transformed_fx_out.keys(): + fx_output_val = transformed_fx_out[key] + non_fx_output_val = transformed_non_fx_out[key] + assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ + f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() +def test_timm_models(): + torch.backends.cudnn.deterministic = True + + sub_model_zoo = model_zoo.get_sub_registry('timm') + + for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + data = data_gen_fn() + if attribute is not None and attribute.has_control_flow: + meta_args = {k: v.to('meta') for k, v in data.items()} + else: + meta_args = None + + trace_and_compare(model_fn, data, output_transform_fn, meta_args) + + +if __name__ == '__main__': + test_timm_models() diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py new file mode 100644 index 0000000000000000000000000000000000000000..eafcaca10b1d72c37750a8b2f87a9c0ec4487a7f --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py @@ -0,0 +1,24 @@ +import pytest +import torch +from packaging import version +from torchaudio_utils import trace_and_compare + +from colossalai.testing import clear_cache_before_run +from tests.kit.model_zoo import model_zoo + + +# We cannot handle the tensors constructed with constant during forward, such as ``torch.empty(0).to(device=Proxy.device)`` +# TODO: We could handle this case by hijacking torch.Tensor.to function. +@pytest.mark.skip +@clear_cache_before_run() +def test_torchaudio_models(): + torch.backends.cudnn.deterministic = True + + sub_model_zoo = model_zoo.get_sub_registry('torchaudio') + + for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + model = model_fn() + trace_and_compare(model, + data_gen_fn, + output_transform_fn, + need_meta=(attribute is not None and attribute.has_control_flow)) diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..239f38680cec5f0659985416260992935645cb10 --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py @@ -0,0 +1,28 @@ +import torch + +from colossalai._analyzer.fx import symbolic_trace + + +def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False): + data = data_gen() + concrete_args = data if need_concrete else {} + meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {} + + model.eval() + + gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args) + + with torch.no_grad(): + non_fx_out = model(**data) + fx_out = gm(**data) + + # compare output + transformed_fx_out = output_transform_fn(fx_out) + transformed_non_fx_out = output_transform_fn(non_fx_out) + + assert len(transformed_fx_out) == len(transformed_non_fx_out) + + for key, fx_output_val in transformed_fx_out.items(): + non_fx_output_val = transformed_non_fx_out[key] + assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ + f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..df02568c00496a0b76c686ac7e3a551960914584 --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -0,0 +1,67 @@ +import pytest +import torch + +from colossalai._analyzer.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run +from tests.kit.model_zoo import model_zoo + +BATCH = 2 +SHAPE = 10 + + +def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): + # trace + model = model_cls() + + # convert to eval for inference + # it is important to set it to eval mode before tracing + # without this statement, the torch.nn.functional.batch_norm will always be in training mode + model.eval() + + gm = symbolic_trace(model, meta_args=meta_args) + gm.eval() + # run forward + with torch.no_grad(): + fx_out = gm(**data) + non_fx_out = model(**data) + + # compare output + transformed_fx_out = output_transform_fn(fx_out) + transformed_non_fx_out = output_transform_fn(non_fx_out) + + assert len(transformed_fx_out) == len(transformed_non_fx_out) + if torch.is_tensor(fx_out): + assert torch.allclose( + fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + else: + assert torch.allclose( + fx_out.values(), + non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + for key in transformed_fx_out.keys(): + fx_output_val = transformed_fx_out[key] + non_fx_output_val = transformed_non_fx_out[key] + if torch.is_tensor(fx_output_val): + assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ + f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' + else: + assert torch.allclose(fx_output_val.values(), non_fx_output_val.values() + ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + + +@clear_cache_before_run() +def test_torchrec_deepfm_models(): + deepfm_models = model_zoo.get_sub_registry('deepfm') + torch.backends.cudnn.deterministic = True + + for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_models.items(): + data = data_gen_fn() + if attribute is not None and attribute.has_control_flow: + meta_args = {k: v.to('meta') for k, v in data.items()} + else: + meta_args = None + + trace_and_compare(model_fn, data, output_transform_fn, meta_args) + + +if __name__ == "__main__": + test_torchrec_deepfm_models() diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9776452be9c8ea220a3fc2651f5aa4a5a3f8b2ad --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -0,0 +1,73 @@ +import pytest +import torch + +from colossalai._analyzer.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run +from tests.kit.model_zoo import model_zoo + +BATCH = 2 +SHAPE = 10 + + +def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): + # trace + model = model_cls() + + # convert to eval for inference + # it is important to set it to eval mode before tracing + # without this statement, the torch.nn.functional.batch_norm will always be in training mode + model.eval() + + gm = symbolic_trace(model, meta_args=meta_args) + gm.eval() + # run forward + with torch.no_grad(): + fx_out = gm(**data) + non_fx_out = model(**data) + + # compare output + transformed_fx_out = output_transform_fn(fx_out) + transformed_non_fx_out = output_transform_fn(non_fx_out) + + assert len(transformed_fx_out) == len(transformed_non_fx_out) + if torch.is_tensor(fx_out): + assert torch.allclose( + fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + else: + assert torch.allclose( + fx_out.values(), + non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + for key in transformed_fx_out.keys(): + fx_output_val = transformed_fx_out[key] + non_fx_output_val = transformed_non_fx_out[key] + if torch.is_tensor(fx_output_val): + assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ + f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' + else: + assert torch.allclose(fx_output_val.values(), non_fx_output_val.values() + ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + + +@clear_cache_before_run() +def test_torchrec_dlrm_models(): + torch.backends.cudnn.deterministic = True + dlrm_models = model_zoo.get_sub_registry('dlrm') + + for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items(): + data = data_gen_fn() + + # dlrm_interactionarch is not supported + # TODO(FrankLeeeee): support this model + if name == 'dlrm_interactionarch': + continue + + if attribute is not None and attribute.has_control_flow: + meta_args = {k: v.to('meta') for k, v in data.items()} + else: + meta_args = None + + trace_and_compare(model_fn, data, output_transform_fn, meta_args) + + +if __name__ == "__main__": + test_torchrec_dlrm_models() diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py new file mode 100644 index 0000000000000000000000000000000000000000..bd259475ae5a51311932f6bb61e89d1c55f538cf --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py @@ -0,0 +1,46 @@ +import torch + +from colossalai._analyzer.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run +from tests.kit.model_zoo import model_zoo + + +@clear_cache_before_run() +def test_torchvision_models(): + torch.backends.cudnn.deterministic = True + tv_sub_registry = model_zoo.get_sub_registry('torchvision') + + for name, (model_fn, data_gen_fn, output_transform_fn, model_attribute) in tv_sub_registry.items(): + data = data_gen_fn() + + if model_attribute is not None and model_attribute.has_stochastic_depth_prob: + model = model_fn(stochastic_depth_prob=0) + else: + model = model_fn() + + gm = symbolic_trace(model) + + model.eval() + gm.eval() + + try: + with torch.no_grad(): + fx_out = gm(**data) + non_fx_out = model(**data) + transformed_out = output_transform_fn(fx_out) + transformed_non_fx_out = output_transform_fn(non_fx_out) + + assert len(transformed_out) == len(transformed_non_fx_out) + + for key in transformed_out.keys(): + fx_val = transformed_out[key] + non_fx_val = transformed_non_fx_out[key] + assert torch.allclose( + fx_val, + non_fx_val), f'{model.__class__.__name__} has inconsistent outputs, {fx_val} vs {non_fx_val}' + except Exception as e: + print(name, e) + + +if __name__ == '__main__': + test_torchvision_models() diff --git a/tests/test_layers/test_1d/checks_1d/__init__.py b/tests/test_layers/test_1d/checks_1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_layers/test_1d/checks_1d/check_layer_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..668b8a334800753d9848347a8d88cba66b605de6 --- /dev/null +++ b/tests/test_layers/test_1d/checks_1d/check_layer_1d.py @@ -0,0 +1,552 @@ +import torch +import torch.distributed as dist +from torch.nn import Parameter + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn import ( + Classifier1D, + Embedding1D, + Linear1D_Col, + Linear1D_Row, + VanillaClassifier, + VocabParallelClassifier1D, + VocabParallelCrossEntropyLoss1D, + VocabParallelEmbedding1D, +) +from colossalai.utils import get_current_device, print_rank_0 + +from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal + + +def check_linear_col(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = Linear1D_Col(INPUT_SIZE, OUTPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = A_master.clone() + A.requires_grad = True + + W_shape = (OUTPUT_SIZE, INPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=0)[i] + W = W.clone() + W.requires_grad = True + + B_shape = (OUTPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + dist.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + B = B.clone() + B.requires_grad = True + + layer.weight = Parameter(W) + layer.bias = Parameter(B) + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = torch.chunk(C_master, DEPTH, dim=-1)[i] + + check_equal(out, C) + print_rank_0('linear_col forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + dist.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, layer.bias.grad) + + print_rank_0('linear_col backward: pass') + + +def check_linear_row(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=-1)[i] + A = A.clone() + A.requires_grad = True + + W_shape = (INPUT_SIZE, OUTPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[i] + W = W.clone() + W.requires_grad = True + + B_shape = (INPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + dist.broadcast(B_master, src=0) + B = B_master.clone() + B.requires_grad = True + + layer.weight = Parameter(W) + layer.bias = Parameter(B) + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = C_master.clone() + + check_equal(out, C) + print_rank_0('linear_row forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + dist.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i] + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + check_equal(B_grad, layer.bias.grad) + + print_rank_0('linear_row backward: pass') + + +def check_embed(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = C_master.clone() + check_equal(out, C) + print_rank_0('embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('embed backward: pass') + + +def check_vocab_parallel_embed(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = C_master.clone() + check_equal(out, C) + print_rank_0('vocab parallel embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('vocab parallel embed backward: pass') + + +def check_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + env.parallel_input_1d = False + parallel_input_1d = env.parallel_input_1d + layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, bias=True) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, bias=True) + layer_master = layer_master.to(dtype).to(device) + + W_master = layer_master.weight.data + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[i] + layer.weight.data.copy_(W) + B_master = layer_master.bias.data + dist.broadcast(B_master, src=0) + B = B_master.clone() + layer.bias.data.copy_(B) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + if parallel_input_1d: + A = torch.chunk(A_master, DEPTH, dim=-1)[i] + A = A.clone() + else: + A = A_master.clone() + A.requires_grad = True + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = C_master.clone() + + check_equal(out, C) + print_rank_0('classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + if parallel_input_1d: + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i] + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + check_equal(B_grad, layer.bias.grad) + + print_rank_0('classifier (no given weight) backward: pass') + + +def check_vocab_parallel_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = VocabParallelClassifier1D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer_master = layer_master.to(dtype).to(device) + + W_master = layer_master.weight.data + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=0)[i] + layer.weight.data.copy_(W) + B_master = layer_master.bias.data + dist.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + layer.bias.data.copy_(B) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = A_master.clone() + A.requires_grad = True + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=-1)[i] + + check_equal(out, C) + print_rank_0('vocab parallel classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, layer.bias.grad) + + print_rank_0('vocab parallel classifier (no given weight) backward: pass') + + +def check_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + env.parallel_input_1d = False + layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = C_master.clone() + check_equal(out, C) + print_rank_0('classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, embed.weight.grad) + + print_rank_0('classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + env.parallel_input_1d = False + layer = VocabParallelClassifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=-1)[i] + check_equal(out, C) + print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, embed.weight.grad) + + print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_loss(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + criterion = VocabParallelCrossEntropyLoss1D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, SEQ_LENGTH), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=-1)[i] + out = out.clone() + out.requires_grad = True + + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('vocab parallel loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i] + check_equal(out_grad, out.grad) + print_rank_0('vocab parallel loss backward: pass') + + +@torch.no_grad() +def check_linear_row_stream_inference(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + stream_chunk_num = 4 + assert HIDDEN_SIZE % stream_chunk_num == 0 + layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=stream_chunk_num) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=-1)[i] + A = A.clone() + + W_shape = (INPUT_SIZE, OUTPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[i] + W = W.clone() + + B_shape = (INPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + dist.broadcast(B_master, src=0) + B = B_master.clone() + + layer.weight = Parameter(W) + layer.bias = Parameter(B) + layer.chunk_weight() + layer.eval() + + out = layer(A) + + A_master = A_master.clone() + W_master = W_master.clone() + B_master = B_master.clone() + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = C_master.clone() + + check_equal(out, C) + print_rank_0('linear_row forward: pass') diff --git a/tests/test_layers/test_1d/checks_1d/common.py b/tests/test_layers/test_1d/checks_1d/common.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7b28613d223bfb0ab249dee01869446a08e406 --- /dev/null +++ b/tests/test_layers/test_1d/checks_1d/common.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +DEPTH = 4 +BATCH_SIZE = 8 +SEQ_LENGTH = 8 +IMG_SIZE = 16 +HIDDEN_SIZE = 8 +NUM_CLASSES = 8 +VOCAB_SIZE = 16 + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True diff --git a/tests/test_layers/test_1d/test_1d.py b/tests/test_layers/test_1d/test_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..89151254247522e4b34dfae1155fca9dc73d0572 --- /dev/null +++ b/tests/test_layers/test_1d/test_1d.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pytest +import torch +from checks_1d.check_layer_1d import * + +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),) + + +def check_layer(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + check_linear_col() + check_linear_row() + check_embed() + check_vocab_parallel_embed() + check_classifier_no_given_weight() + check_vocab_parallel_classifier_no_given_weight() + check_classifier_given_embed_weight() + check_vocab_parallel_classifier_given_embed_weight() + check_vocab_parallel_loss() + + check_linear_row_stream_inference() + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_1d(): + spawn(check_layer, 4) + + +if __name__ == '__main__': + test_1d() diff --git a/tests/test_layers/test_2d/checks_2d/__init__.py b/tests/test_layers/test_2d/checks_2d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_layers/test_2d/checks_2d/check_layer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..e030e473a36311250bd0310b94261a46b77d561a --- /dev/null +++ b/tests/test_layers/test_2d/checks_2d/check_layer_2d.py @@ -0,0 +1,741 @@ +import torch +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn import (Classifier2D, CrossEntropyLoss2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, + VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2D, + VocabParallelCrossEntropyLoss2D, VocabParallelEmbedding2D) +from colossalai.utils import get_current_device, print_rank_0 + +from .common import (BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal) + + +def check_linear(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = HIDDEN_SIZE + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + layer = Linear2D(INPUT_SIZE, OUTPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + W_shape = (INPUT_SIZE, OUTPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + torch.distributed.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=0)[i] + W = torch.chunk(W, DEPTH, dim=-1)[j] + W = W.clone() + W.requires_grad = True + + B_shape = (OUTPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=-1)[j] + B = torch.chunk(B, DEPTH, dim=-1)[i] + B = B.clone() + B.requires_grad = True + + layer.weight.data.copy_(W) + layer.bias.data.copy_(B) + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master) + B_master + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + + check_equal(out, C) + print_rank_0('linear forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + # if i == 0: + check_equal(B_grad, layer.bias.grad) + + print_rank_0('linear backward: pass') + + +def check_layernorm(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + EPS = 1e-12 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + layernorm = LayerNorm2D(INPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + out = layernorm(A) + + A_master = A_master.clone() + A_master.requires_grad = True + E_master = torch.sum(A_master, dim=-1, keepdim=True) + E_master /= INPUT_SIZE + V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True) + V_master /= INPUT_SIZE + V_master = V_master - E_master * E_master + V_master = 1.0 / torch.sqrt(V_master + EPS) + C_master = (A_master - E_master) * V_master + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + + check_equal(out, C) + print_rank_0('layer norm forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + out.backward(grad) + + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + check_equal(A_grad, A.grad) + print_rank_0('layer norm backward: pass') + + +def check_embed(): + device = get_current_device() + dtype = torch.float32 + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + embed = Embedding2D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[j] + weight = torch.chunk(weight, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('embed backward: pass') + + +def check_patch_embed(): + device = get_current_device() + dtype = torch.float32 + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + layer = PatchEmbedding2D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) + torch.nn.init.ones_(layer.cls_token) + torch.nn.init.ones_(layer.pos_embed) + layer = layer.to(device) + + layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) + torch.nn.init.ones_(layer_master.cls_token) + torch.nn.init.ones_(layer_master.pos_embed) + layer_master = layer_master.to(device) + + proj_weight_master = layer_master.weight.data + torch.distributed.broadcast(proj_weight_master, src=0) + proj_weight = torch.chunk(proj_weight_master, DEPTH, dim=0)[j] + proj_weight = torch.chunk(proj_weight, DEPTH, dim=0)[i] + layer.weight.data.copy_(proj_weight) + proj_bias_master = layer_master.bias.data + torch.distributed.broadcast(proj_bias_master, src=0) + proj_bias = torch.chunk(proj_bias_master, DEPTH, dim=0)[j] + proj_bias = torch.chunk(proj_bias, DEPTH, dim=0)[i] + layer.bias.data.copy_(proj_bias) + + A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(A) + + A_master = A_master.clone() + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('patch embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + cls_grad_master = layer_master.cls_token.grad + cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[j] + cls_grad = torch.chunk(cls_grad, DEPTH, dim=-1)[i] + check_equal(cls_grad, layer.cls_token.grad) + + pos_grad_master = layer_master.pos_embed.grad + pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[j] + pos_grad = torch.chunk(pos_grad, DEPTH, dim=-1)[i] + check_equal(pos_grad, layer.pos_embed.grad) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, layer.weight.grad) + + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[j] + bias_grad = torch.chunk(bias_grad, DEPTH)[i] + check_equal(bias_grad, layer.bias.grad) + print_rank_0('patch embed backward: pass') + + +def check_vocab_parallel_embed(): + device = get_current_device() + dtype = torch.float32 + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + embed = VocabParallelEmbedding2D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[j] + weight = torch.chunk(weight, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('vocab parallel embed backward: pass') + + +def check_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = NUM_CLASSES + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + layer = Classifier2D(INPUT_SIZE, OUTPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randint(5, A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + W_shape = (OUTPUT_SIZE, INPUT_SIZE) + W_master = torch.randint(5, W_shape, dtype=dtype, device=device) + torch.distributed.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[j] + W = torch.chunk(W, DEPTH, dim=-1)[i] + W = W.clone() + layer.weight.data.copy_(W) + # W.requires_grad = True + + B_shape = (OUTPUT_SIZE, ) + B_master = torch.randint(5, B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + # B = torch.chunk(B_master, DEPTH, dim=0)[j] + B = B_master.clone() + layer.bias.data.copy_(B) + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = torch.chunk(C_master, DEPTH, dim=0)[i] + # C = torch.chunk(C, DEPTH, dim=-1)[j] + + check_equal(out, C) + print_rank_0('classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + # grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + # B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + # if i == 0: + check_equal(B_grad, layer.bias.grad) + + print_rank_0('classifier (no given weight) backward: pass') + + +def check_vocab_parallel_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + layer = VocabParallelClassifier2D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer = layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer_master = layer_master.to(dtype).to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[i] + weight = torch.chunk(weight, DEPTH, dim=-1)[j] + layer.weight.data.copy_(weight) + bias_master = layer_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, DEPTH)[j] + bias = torch.chunk(bias, DEPTH)[i] + layer.bias.data.copy_(bias) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + B_grad = torch.chunk(B_grad, DEPTH)[j] + B_grad = torch.chunk(B_grad, DEPTH)[i] + check_equal(B_grad, layer.bias.grad) + print_rank_0('vocab parallel classifier (no given weight) backward: pass') + + +def check_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + embed = Embedding2D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[j] + weight = torch.chunk(weight, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + layer = Classifier2D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + check_equal(out, C) + print_rank_0('classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, embed.weight.grad) + print_rank_0('classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + embed = VocabParallelEmbedding2D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[j] + weight = torch.chunk(weight, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + layer = VocabParallelClassifier2D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, embed.weight.grad) + print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + + +def check_loss(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + criterion = CrossEntropyLoss2D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=0)[i] + out = out.clone() + out.requires_grad = True + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('cross entropy loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] + check_equal(out_grad, out.grad) + print_rank_0('cross entropy loss backward: pass') + + +def check_vocab_parallel_loss(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + criterion = VocabParallelCrossEntropyLoss2D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=0)[i] + out = torch.chunk(out, DEPTH, dim=-1)[j] + out = out.clone() + out.requires_grad = True + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('vocab parallel cross entropy loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] + out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[j] + check_equal(out_grad, out.grad) + print_rank_0('vocab parallel cross entropy loss backward: pass') + + +# def check_attention(): +# device = get_current_device() +# dtype = torch.float32 +# INPUT_SIZE = HIDDEN_SIZE +# NUM_ATTENTION_HEADS = 2 + +# j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) +# i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + +# layer = TransformerSelfAttention2D( +# HIDDEN_SIZE, +# NUM_ATTENTION_HEADS, +# attention_dropout_prob=0.5, +# hidden_dropout_prob=0.5, +# ) + +# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) +# A_master = torch.randn(A_shape, dtype=dtype, device=device) +# torch.distributed.broadcast(A_master, src=0) +# A = torch.chunk(A_master, DEPTH, dim=0)[i] +# A = torch.chunk(A, DEPTH, dim=-1)[j] +# A = A.clone() +# A.requires_grad = True + +# mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH) +# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device) + +# out = layer(A, attention_mask) +# assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH) +# print_rank_0('self attention forward: pass') + +# grad_shape = out.shape +# grad = torch.randn(grad_shape, dtype=dtype, device=device) + +# out.backward(grad) +# assert A.grad.shape == A.shape +# print_rank_0('self attention backward: pass') + +# def check_mlp(): +# device = get_current_device() +# dtype = torch.float32 +# INPUT_SIZE = HIDDEN_SIZE + +# j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) +# i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + +# layer = TransformerMLP2D( +# HIDDEN_SIZE, +# dropout_prob=0.5, +# act_func='gelu', +# ) + +# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) +# A_master = torch.randn(A_shape, dtype=dtype, device=device) +# torch.distributed.broadcast(A_master, src=0) +# A = torch.chunk(A_master, DEPTH, dim=0)[i] +# A = torch.chunk(A, DEPTH, dim=-1)[j] +# A = A.clone() +# A.requires_grad = True + +# out = layer(A) +# assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH) +# print_rank_0('mlp forward: pass') + +# grad_shape = out.shape +# grad = torch.randn(grad_shape, dtype=dtype, device=device) + +# out.backward(grad) +# assert A.grad.shape == A.shape +# print_rank_0('mlp backward: pass') + +# def check_transformerlayer(): +# device = get_current_device() +# dtype = torch.float32 +# INPUT_SIZE = HIDDEN_SIZE +# NUM_ATTENTION_HEADS = 2 + +# j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) +# i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + +# layer = TransformerLayer2D(HIDDEN_SIZE, +# NUM_ATTENTION_HEADS, +# act_func='gelu', +# attention_dropout_prob=0.5, +# hidden_dropout_prob=0.5) + +# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) +# A_master = torch.randn(A_shape, dtype=dtype, device=device) +# torch.distributed.broadcast(A_master, src=0) +# A = torch.chunk(A_master, DEPTH, dim=0)[i] +# A = torch.chunk(A, DEPTH, dim=-1)[j] +# A = A.clone() +# A.requires_grad = True + +# mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH) +# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device) + +# out = layer(A, attention_mask) +# assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH) +# print_rank_0('transformerlayer forward: pass') + +# grad_shape = out.shape +# grad = torch.randn(grad_shape, dtype=dtype, device=device) + +# out.backward(grad) +# assert A.grad.shape == A.shape +# print_rank_0('transformerlayer backward: pass') diff --git a/tests/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_layers/test_2d/checks_2d/check_operation_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e37b1ec3097b29b55a4744111131ef3bfdea44 --- /dev/null +++ b/tests/test_layers/test_2d/checks_2d/check_operation_2d.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D +from colossalai.utils import get_current_device +from colossalai.utils import print_rank_0 +from .common import check_equal, BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE, DEPTH + + +def check_AB(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + dtype = torch.float + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + B = torch.chunk(B, DEPTH, dim=-1)[j] + B = B.clone() + B.requires_grad = True + + out_shape = (BATCH_SIZE // DEPTH, SEQ_LENGTH, 4 * HIDDEN_SIZE // DEPTH) + + out = Matmul_AB_2D.apply(A, B, DEPTH, out_shape, i, j, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, + data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + A_master = A_master.clone() + A_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, B_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + # check forward correctness + check_equal(out, C) + print_rank_0('AB forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + + out.backward(grad) + + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + # check backward correctness + check_equal(A_grad, A.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + # check backward correctness + check_equal(B_grad, B.grad) + print_rank_0('AB backward: pass') + + +def check_ABT(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + dtype = torch.float + device = get_current_device() + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + C_master = torch.randn(C_shape, dtype=dtype, device=device) + torch.distributed.broadcast(C_master, src=0) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + C = C.clone() + C.requires_grad = True + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + B = torch.chunk(B, DEPTH, dim=-1)[j] + B = B.clone() + B.requires_grad = True + + out = Matmul_ABT_2D.apply(C, B, DEPTH, (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH), i, j, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank, + pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + C_master = C_master.clone() + C_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + A_master = torch.matmul(C_master, B_master.transpose(0, 1)) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + check_equal(out, A) + print_rank_0('ABT forward: pass') + + grad_shape = A_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + + # backward + out.backward(grad) + + A_master.backward(grad_master) + C_grad = C_master.grad + C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i] + C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j] + check_equal(C_grad, C.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + check_equal(B_grad, B.grad) + print_rank_0('ABT backward: pass') + + +def check_ATB(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + device = get_current_device() + dtype = torch.float + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + C_master = torch.randn(C_shape, dtype=dtype, device=device) + torch.distributed.broadcast(C_master, src=0) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + C = C.clone() + C.requires_grad = True + + out = Matmul_ATB_2D.apply(A, C, DEPTH, (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH), i, j, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank, + pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + A_master = A_master.clone() + A_master.requires_grad = True + C_master = C_master.clone() + C_master.requires_grad = True + B_master = torch.matmul( + A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1])) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + B = torch.chunk(B, DEPTH, dim=-1)[j] + check_equal(out, B) + print_rank_0('ATB forward: pass') + + grad_shape = B_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + + out.backward(grad) + + B_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + check_equal(A_grad, A.grad) + + C_grad = C_master.grad + C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i] + C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j] + check_equal(C_grad, C.grad) + print_rank_0('ATB backward: pass') diff --git a/tests/test_layers/test_2d/checks_2d/common.py b/tests/test_layers/test_2d/checks_2d/common.py new file mode 100644 index 0000000000000000000000000000000000000000..8c855c18bc26c7e06507f2b180d87cb0cae8b67f --- /dev/null +++ b/tests/test_layers/test_2d/checks_2d/common.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +DEPTH = 2 +BATCH_SIZE = 8 +SEQ_LENGTH = 8 +HIDDEN_SIZE = 8 +NUM_CLASSES = 8 +VOCAB_SIZE = 16 +IMG_SIZE = 16 + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) diff --git a/tests/test_layers/test_2d/test_2d.py b/tests/test_layers/test_2d/test_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..bcea5ce7b25dbad1fa9b84c38651feb347ead0cf --- /dev/null +++ b/tests/test_layers/test_2d/test_2d.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pytest +import torch +from checks_2d.check_layer_2d import ( + check_classifier_given_embed_weight, + check_classifier_no_given_weight, + check_embed, + check_layernorm, + check_linear, + check_loss, + check_patch_embed, + check_vocab_parallel_classifier_given_embed_weight, + check_vocab_parallel_classifier_no_given_weight, + check_vocab_parallel_embed, + check_vocab_parallel_loss, +) +from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB + +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),) + + +def check_operations(): + check_AB() + check_ABT() + check_ATB() + + +def check_layer(): + check_linear() + check_layernorm() + check_embed() + check_patch_embed() + check_vocab_parallel_embed() + check_classifier_no_given_weight() + check_vocab_parallel_classifier_no_given_weight() + check_classifier_given_embed_weight() + check_vocab_parallel_classifier_given_embed_weight() + check_loss() + check_vocab_parallel_loss() + + +def check_layer_and_operation(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + # check_operations() + check_layer() + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_2d(): + spawn(check_layer_and_operation, 4) + + +if __name__ == '__main__': + test_2d() diff --git a/tests/test_layers/test_2p5d/checks_2p5d/__init__.py b/tests/test_layers/test_2p5d/checks_2p5d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..a8f551093b1ef782f8bef64d4241c1400aa6bdde --- /dev/null +++ b/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py @@ -0,0 +1,754 @@ +import torch +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn import (Classifier2p5D, CrossEntropyLoss2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, + PatchEmbedding2p5D, VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2p5D, + VocabParallelCrossEntropyLoss2p5D, VocabParallelEmbedding2p5D) +from colossalai.utils import get_current_device, print_rank_0 +from torch.nn import Parameter + +from .common import * + + +def check_linear(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + layer = Linear2p5D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, skip_bias_add=False) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + W_shape = (INPUT_SIZE, OUTPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + torch.distributed.broadcast(W_master, src=0) + W = torch.chunk(W_master, TESSERACT_DIM, dim=0)[i] + W = torch.chunk(W, TESSERACT_DIM, dim=-1)[j] + W = W.clone() + W.requires_grad = True + + B_shape = (OUTPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j] + B = B.clone() + B.requires_grad = True + + layer.weight = Parameter(W) + layer.bias = Parameter(B) + out = layer(A) + bias = layer.bias + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master) + B_master + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + + check_equal(out, C) + print_rank_0('linear forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i] + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j] + if i == 0: + check_equal(B_grad, layer.bias.grad) + + print_rank_0('linear backward: pass') + + +def check_layernorm(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + EPS = 1e-12 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + layernorm = LayerNorm2p5D(INPUT_SIZE, dtype=dtype) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + out = layernorm(A) + + A_master = A_master.clone() + A_master.requires_grad = True + E_master = torch.sum(A_master, dim=-1, keepdim=True) + E_master /= INPUT_SIZE + V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True) + V_master /= INPUT_SIZE + V_master = V_master - E_master * E_master + V_master = 1.0 / torch.sqrt(V_master + EPS) + C_master = (A_master - E_master) * V_master + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + + check_equal(out, C) + print_rank_0('layer norm forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + out.backward(grad) + + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(A_grad, A.grad) + print_rank_0('layer norm backward: pass') + + +def check_embed(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j] + weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('embed backward: pass') + + +def check_patch_embed(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + layer = PatchEmbedding2p5D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) + torch.nn.init.ones_(layer.cls_token) + torch.nn.init.ones_(layer.pos_embed) + layer = layer.to(device) + + layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) + torch.nn.init.ones_(layer_master.cls_token) + torch.nn.init.ones_(layer_master.pos_embed) + layer_master = layer_master.to(device) + + proj_weight_master = layer_master.weight.data + torch.distributed.broadcast(proj_weight_master, src=0) + proj_weight = torch.chunk(proj_weight_master, TESSERACT_DIM, dim=0)[j] + proj_weight = torch.chunk(proj_weight, TESSERACT_DIM, dim=0)[i] + layer.weight.data.copy_(proj_weight) + proj_bias_master = layer_master.bias.data + torch.distributed.broadcast(proj_bias_master, src=0) + proj_bias = torch.chunk(proj_bias_master, TESSERACT_DIM, dim=0)[j] + proj_bias = torch.chunk(proj_bias, TESSERACT_DIM, dim=0)[i] + layer.bias.data.copy_(proj_bias) + + A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(A) + + A_master = A_master.clone() + C_master = layer_master(A_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('patch embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + cls_grad_master = layer_master.cls_token.grad + cls_grad = torch.chunk(cls_grad_master, TESSERACT_DIM, dim=-1)[j] + cls_grad = torch.chunk(cls_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(cls_grad, layer.cls_token.grad) + + pos_grad_master = layer_master.pos_embed.grad + pos_grad = torch.chunk(pos_grad_master, TESSERACT_DIM, dim=-1)[j] + pos_grad = torch.chunk(pos_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(pos_grad, layer.pos_embed.grad) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] + check_equal(B_grad, layer.weight.grad) + + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[j] + bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[i] + check_equal(bias_grad, layer.bias.grad) + print_rank_0('patch embed backward: pass') + + +def check_vocab_parallel_embed(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j] + weight = torch.chunk(weight, TESSERACT_DIM, dim=0)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('vocab parallel embed backward: pass') + + +def check_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = NUM_CLASSES + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + + layer = Classifier2p5D(INPUT_SIZE, OUTPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randint(5, A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + W_shape = (OUTPUT_SIZE, INPUT_SIZE) + W_master = torch.randint(5, W_shape, dtype=dtype, device=device) + torch.distributed.broadcast(W_master, src=0) + # W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j] + W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j] + W = torch.chunk(W, TESSERACT_DIM, dim=-1)[i] + W = W.clone() + layer.weight.data.copy_(W) + # W.requires_grad = True + + B_shape = (OUTPUT_SIZE, ) + B_master = torch.randint(5, B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + # B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j] + B = B_master.clone() + layer.bias.data.copy_(B) + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + # C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + + check_equal(out, C) + print_rank_0('classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + # grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + # B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j] + # if i == 0: + check_equal(B_grad, layer.bias.grad) + + print_rank_0('classifier (no given weight) backward: pass') + + +def check_vocab_parallel_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer = layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer_master = layer_master.to(dtype).to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=0)[i] + weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[j] + layer.weight.data.copy_(weight) + bias_master = layer_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, TESSERACT_DIM)[j] + layer.bias.data.copy_(bias) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i] + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM)[j] + if i == 0: + check_equal(B_grad, layer.bias.grad) + print_rank_0('vocab parallel classifier (no given weight) backward: pass') + + +def check_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j] + weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[i] + embed.weight.data.copy_(weight) + + layer = Classifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + check_equal(out, C) + print_rank_0('classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(W_grad, embed.weight.grad) + print_rank_0('classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j] + weight = torch.chunk(weight, TESSERACT_DIM, dim=0)[i] + embed.weight.data.copy_(weight) + + layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i] + check_equal(W_grad, embed.weight.grad) + print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + + +def check_loss(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + criterion = CrossEntropyLoss2p5D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i] + out = out.clone() + out.requires_grad = True + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('cross entropy loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i] + check_equal(out_grad, out.grad) + print_rank_0('cross entropy loss backward: pass') + + +def check_vocab_parallel_loss(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + criterion = VocabParallelCrossEntropyLoss2p5D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i] + out = torch.chunk(out, TESSERACT_DIM, dim=-1)[j] + out = out.clone() + out.requires_grad = True + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('vocab parallel cross entropy loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i] + out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(out_grad, out.grad) + print_rank_0('vocab parallel cross entropy loss backward: pass') + + +# def check_attention(): +# device = get_current_device() +# dtype = torch.float32 +# INPUT_SIZE = HIDDEN_SIZE +# NUM_ATTENTION_HEADS = 2 + +# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) +# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) +# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + +# layer = TransformerSelfAttention2p5D( +# HIDDEN_SIZE, NUM_ATTENTION_HEADS, +# attention_dropout_prob=0.5, +# hidden_dropout_prob=0.5, +# dtype=dtype, +# ) + +# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) +# A_master = torch.randn(A_shape, dtype=dtype, device=device) +# torch.distributed.broadcast(A_master, src=0) +# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] +# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] +# A = A.clone() +# A.requires_grad = True + +# mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH) +# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device) + +# out = layer(A, attention_mask) +# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM) +# print_rank_0('self attention forward: pass') + +# grad_shape = out.shape +# grad = torch.randn(grad_shape, dtype=dtype, device=device) + +# out.backward(grad) +# assert A.grad.shape == A.shape +# print_rank_0('self attention backward: pass') + +# def check_mlp(): +# device = get_current_device() +# dtype = torch.float32 +# INPUT_SIZE = HIDDEN_SIZE + +# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) +# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) +# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + +# layer = TransformerMLP2p5D( +# HIDDEN_SIZE, +# mlp_ratio=1, +# dropout_prob=0.5, +# act_func='gelu', +# dtype=dtype, +# ) + +# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) +# A_master = torch.randn(A_shape, dtype=dtype, device=device) +# torch.distributed.broadcast(A_master, src=0) +# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] +# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] +# A = A.clone() +# A.requires_grad = True + +# out = layer(A) +# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM) +# print_rank_0('mlp forward: pass') + +# grad_shape = out.shape +# grad = torch.randn(grad_shape, dtype=dtype, device=device) + +# out.backward(grad) +# assert A.grad.shape == A.shape +# print_rank_0('mlp backward: pass') + +# def check_transformerlayer(): +# device = get_current_device() +# dtype = torch.float32 +# INPUT_SIZE = HIDDEN_SIZE +# NUM_ATTENTION_HEADS = 2 + +# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) +# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) +# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + +# layer = TransformerLayer2p5D( +# HIDDEN_SIZE, +# NUM_ATTENTION_HEADS, +# act_func='gelu', +# attention_dropout_prob=0.5, +# hidden_dropout_prob=0.5, +# dtype=dtype, +# ) + +# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) +# A_master = torch.randn(A_shape, dtype=dtype, device=device) +# torch.distributed.broadcast(A_master, src=0) +# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] +# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] +# A = A.clone() +# A.requires_grad = True + +# mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH) +# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device) + +# out = layer(A, attention_mask) +# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM) +# print_rank_0('transformerlayer forward: pass') + +# grad_shape = out.shape +# grad = torch.randn(grad_shape, dtype=dtype, device=device) + +# out.backward(grad) +# assert A.grad.shape == A.shape +# print_rank_0('transformerlayer backward: pass') diff --git a/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..d0c3b02fccba589507bf8e2af25846767636c734 --- /dev/null +++ b/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py @@ -0,0 +1,216 @@ +import torch + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, \ + Matmul_ATB_2p5D +from colossalai.utils import get_current_device +from colossalai.utils import print_rank_0 +from .common import * + + +def check_AB(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + dtype = torch.float + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] + B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] + B = B.clone() + B.requires_grad = True + + out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM) + out = Matmul_AB_2p5D.apply(A, B, TESSERACT_DIM, out_shape, i, j, k, ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, pipeline_parallel_rank, + pipeline_parallel_size, tensor_parallel_size) + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + A_master = A_master.clone() + A_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, B_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + # check forward correctness + check_equal(out, C) + print_rank_0('AB forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + + out.backward(grad) + + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + # check backward correctness + check_equal(A_grad, A.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] + # check backward correctness + check_equal(B_grad, B.grad) + print_rank_0('AB backward: pass') + + +def check_ABT(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + dtype = torch.float + device = get_current_device() + + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + C_master = torch.randn(C_shape, dtype=dtype, device=device) + torch.distributed.broadcast(C_master, src=0) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + C = C.clone() + C.requires_grad = True + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] + B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] + B = B.clone() + B.requires_grad = True + + out = Matmul_ABT_2p5D.apply(C, B, TESSERACT_DIM, + (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM), i, j, k, + ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, + pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + C_master = C_master.clone() + C_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + A_master = torch.matmul(C_master, B_master.transpose(0, 1)) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + check_equal(out, A) + print_rank_0('ABT forward: pass') + + grad_shape = A_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + + # backward + out.backward(grad) + + A_master.backward(grad_master) + C_grad = C_master.grad + C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i] + C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(C_grad, C.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(B_grad, B.grad) + print_rank_0('ABT backward: pass') + + +def check_ATB(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( + ParallelMode.PIPELINE) + pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( + ParallelMode.PIPELINE) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + device = get_current_device() + dtype = torch.float + + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + C_master = torch.randn(C_shape, dtype=dtype, device=device) + torch.distributed.broadcast(C_master, src=0) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + C = C.clone() + C.requires_grad = True + + out = Matmul_ATB_2p5D.apply(A, C, TESSERACT_DIM, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM), + i, j, k, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, + data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, + tensor_parallel_size) + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + A_master = A_master.clone() + A_master.requires_grad = True + C_master = C_master.clone() + C_master.requires_grad = True + B_master = torch.matmul( + A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1])) + B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] + B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] + check_equal(out, B) + print_rank_0('ATB forward: pass') + + grad_shape = B_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + + out.backward(grad) + + B_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(A_grad, A.grad) + + C_grad = C_master.grad + C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i] + C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(C_grad, C.grad) + print_rank_0('ATB backward: pass') diff --git a/tests/test_layers/test_2p5d/checks_2p5d/common.py b/tests/test_layers/test_2p5d/checks_2p5d/common.py new file mode 100644 index 0000000000000000000000000000000000000000..aff85f109666d7cdf9e65173eda851368c39694c --- /dev/null +++ b/tests/test_layers/test_2p5d/checks_2p5d/common.py @@ -0,0 +1,14 @@ +import torch + +TESSERACT_DIM = 2 +TESSERACT_DEP = 2 +BATCH_SIZE = 8 +SEQ_LENGTH = 8 +HIDDEN_SIZE = 8 +NUM_CLASSES = 8 +VOCAB_SIZE = 16 +IMG_SIZE = 16 + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) \ No newline at end of file diff --git a/tests/test_layers/test_2p5d/test_2p5d.py b/tests/test_layers/test_2p5d/test_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..373d834d003286a0f40cd8e4063fa1e4782cdb9c --- /dev/null +++ b/tests/test_layers/test_2p5d/test_2p5d.py @@ -0,0 +1,57 @@ +import pytest +import torch +from checks_2p5d.check_layer_2p5d import * +from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB + +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict(parallel=dict( + pipeline=dict(size=1), + tensor=dict(size=4, mode='2.5d', depth=1), +),) + + +def check_operations(): + check_AB() + check_ABT() + check_ATB() + + +def check_layer(): + check_linear() + check_layernorm() + check_embed() + check_patch_embed() + check_vocab_parallel_embed() + check_classifier_no_given_weight() + check_vocab_parallel_classifier_no_given_weight() + check_classifier_given_embed_weight() + check_vocab_parallel_classifier_given_embed_weight() + check_loss() + check_vocab_parallel_loss() + + +def check_layer_and_operation(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + check_operations() + check_layer() + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_2p5d(): + spawn(check_layer_and_operation, 4) + + +if __name__ == '__main__': + test_2p5d() diff --git a/tests/test_layers/test_3d/checks_3d/__init__.py b/tests/test_layers/test_3d/checks_3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..e946a1f5912d69cebcbe64bc3c317fc2c9774207 --- /dev/null +++ b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py @@ -0,0 +1,875 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import time + +import torch + +from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.core import global_context +from colossalai.logging import get_dist_logger +from colossalai.nn import ( + Classifier3D, + CrossEntropyLoss3D, + Embedding3D, + LayerNorm3D, + Linear3D, + PatchEmbedding3D, + VanillaClassifier, + VanillaPatchEmbedding, + VocabParallelClassifier3D, + VocabParallelCrossEntropyLoss3D, + VocabParallelEmbedding3D, +) +from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env +from colossalai.utils import get_current_device, print_rank_0 + +from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal + + +def check_linear(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, bias=True) + layer = layer.to(device) + layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE) + layer_master = layer_master.to(device) + + weight_master = layer_master.weight.data.transpose(0, 1).contiguous() + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[k] + weight = torch.chunk(weight, DEPTH, dim=-1)[j] + weight = torch.chunk(weight, DEPTH, dim=0)[i] + layer.weight.data.copy_(weight) + bias_master = layer_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, DEPTH)[j] + layer.bias.data.copy_(bias) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[k] + A = torch.chunk(A, DEPTH, dim=0)[j] + A = A.clone() + A.requires_grad = True + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'linear forward: {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + C = torch.chunk(C, DEPTH, dim=0)[k] + logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = torch.chunk(grad, DEPTH, dim=0)[k] + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('linear backward: {:.3f} s'.format(bwd_end - bwd_start), logger) + + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] + logger.info('Rank {} linear backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad))) + + B_grad = layer_master.weight.grad.transpose(0, 1) + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) + + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[j] + logger.info('Rank {} linear backward (bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_layernorm(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + INPUT_SIZE = HIDDEN_SIZE + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + norm = LayerNorm3D(INPUT_SIZE, eps=1e-6) + norm = norm.to(device) + norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6) + norm_master = norm_master.to(device) + + weight_master = norm_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH)[k] + norm.weight.data.copy_(weight) + bias_master = norm_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, DEPTH)[k] + norm.bias.data.copy_(bias) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[k] + A = torch.chunk(A, DEPTH, dim=0)[j] + A = A.clone() + A.requires_grad = True + + fwd_start = time.time() + out = norm(A) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'layer norm forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), + fwd_end - fwd_start), logger) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = norm_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[k] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[k] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('layer norm backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] + logger.info('Rank {} layernorm backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad))) + + bias_grad = norm_master.weight.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[k] + logger.info('Rank {} layernorm backward (weight_grad): {}'.format(rank, check_equal(bias_grad, norm.weight.grad))) + + bias_grad = norm_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[k] + logger.info('Rank {} layernorm backward (bias_grad): {}'.format(rank, check_equal(bias_grad, norm.bias.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_classifier_no_given_weight(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + INPUT_SIZE = HIDDEN_SIZE + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, bias=True) + layer = layer.to(device) + + layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True) + layer_master = layer_master.to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[k] + layer.weight.data.copy_(weight) + bias_master = layer_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + layer.bias.data.copy_(bias_master) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[k] + A = torch.chunk(A, DEPTH, dim=0)[j] + A = A.clone() + A.requires_grad = True + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + grad = grad.clone() + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] + logger.info('Rank {} classifier (no given weight) backward (input_grad): {}'.format( + rank, check_equal(A_grad, A.grad))) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + if j == k: + logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format( + rank, check_equal(B_grad, layer.weight.grad))) + else: + logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format( + rank, layer.weight.grad is None)) + + bias_grad = layer_master.bias.grad + logger.info('Rank {} classifier (no given weight) backward (bias_grad): {}'.format( + rank, check_equal(bias_grad, layer.bias.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_vocab_parallel_classifier_no_given_weight(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + INPUT_SIZE = HIDDEN_SIZE + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = VocabParallelClassifier3D(INPUT_SIZE, VOCAB_SIZE, bias=True) + layer = layer.to(device) + + layer_master = VanillaClassifier(INPUT_SIZE, VOCAB_SIZE, bias=True) + layer_master = layer_master.to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[j] + weight = torch.chunk(weight, DEPTH, dim=0)[i] + weight = torch.chunk(weight, DEPTH, dim=-1)[k] + layer.weight.data.copy_(weight) + bias_master = layer_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, DEPTH)[j] + layer.bias.data.copy_(bias) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[k] + A = torch.chunk(A, DEPTH, dim=0)[j] + A = A.clone() + A.requires_grad = True + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + C = torch.chunk(C, DEPTH, dim=0)[k] + logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = torch.chunk(grad, DEPTH, dim=0)[k] + grad = grad.clone() + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('vocab parallel classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), + logger) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] + logger.info('Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}'.format( + rank, check_equal(A_grad, A.grad))) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format( + rank, check_equal(B_grad, layer.weight.grad))) + + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[j] + logger.info('Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}'.format( + rank, check_equal(bias_grad, layer.bias.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_classifier_given_embed_weight(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + embed = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[k] + embed.weight.data.copy_(weight) + + layer = Classifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + + fwd_start = time.time() + out = layer(embed(A)) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + grad = grad.clone() + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + if j == k: + logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format( + rank, check_equal(B_grad, embed.weight.grad))) + else: + logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format( + rank, embed.weight.grad is None)) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_vocab_parallel_classifier_given_embed_weight(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + embed = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(device) + + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[j] + weight = torch.chunk(weight, DEPTH, dim=0)[i] + weight = torch.chunk(weight, DEPTH, dim=-1)[k] + embed.weight.data.copy_(weight) + + layer = VocabParallelClassifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + + fwd_start = time.time() + out = layer(embed(A)) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + C = torch.chunk(C, DEPTH, dim=0)[k] + logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = torch.chunk(grad, DEPTH, dim=0)[k] + grad = grad.clone() + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('vocab parallel classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), + logger) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank, + check_equal(B_grad, + embed.weight.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_patch_embed(): + rank = torch.distributed.get_rank() + device = get_current_device() + logger = get_dist_logger() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE) + torch.nn.init.ones_(layer.cls_token) + torch.nn.init.ones_(layer.pos_embed) + layer = layer.to(device) + + layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE) + torch.nn.init.ones_(layer_master.cls_token) + torch.nn.init.ones_(layer_master.pos_embed) + layer_master = layer_master.to(device) + + proj_weight_master = layer_master.weight.data + torch.distributed.broadcast(proj_weight_master, src=0) + proj_weight = torch.chunk(proj_weight_master, DEPTH, dim=0)[k] + layer.weight.data.copy_(proj_weight) + proj_bias_master = layer_master.bias.data + torch.distributed.broadcast(proj_bias_master, src=0) + proj_bias = torch.chunk(proj_bias_master, DEPTH)[k] + layer.bias.data.copy_(proj_bias) + + A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE) + A_master = torch.randn(A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'patch embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), + fwd_end - fwd_start), logger) + + A_master = A_master.clone() + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[k] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[k] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + grad = grad.clone() + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('patch embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + cls_grad_master = layer_master.cls_token.grad + cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k] + logger.info('Rank {} patch embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad))) + + pos_grad_master = layer_master.pos_embed.grad + pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k] + logger.info('Rank {} patch embed backward (pos_embed_grad): {}'.format(rank, + check_equal(pos_grad, layer.pos_embed.grad))) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] + logger.info('Rank {} patch embed backward (proj_weight_grad): {}'.format(rank, + check_equal(B_grad, layer.weight.grad))) + + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[k] + logger.info('Rank {} patch embed backward (proj_bias_grad): {}'.format(rank, + check_equal(bias_grad, layer.bias.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_embed(): + rank = torch.distributed.get_rank() + device = get_current_device() + logger = get_dist_logger() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE) + layer = layer.to(device) + layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + layer_master = layer_master.to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[k] + layer.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + logger.info('embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), + fwd_end - fwd_start), + ranks=[0]) + + A_master = A_master.clone() + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[k] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[k] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + grad = grad.clone() + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + logger.info('embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_vocab_parallel_embed(): + rank = torch.distributed.get_rank() + device = get_current_device() + logger = get_dist_logger() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE) + layer = layer.to(device) + layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + layer_master = layer_master.to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[j] + weight = torch.chunk(weight, DEPTH, dim=0)[i] + weight = torch.chunk(weight, DEPTH, dim=-1)[k] + layer.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + logger.info('vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), + ranks=[0]) + + A_master = A_master.clone() + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[k] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[k] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + grad = grad.clone() + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + logger.info('vocab parallel embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank, + check_equal(B_grad, + layer.weight.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_loss(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + + criterion = CrossEntropyLoss3D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=0)[i] + out = torch.chunk(out, DEPTH, dim=0)[j] + out = out.clone() + out.requires_grad = True + + fwd_start = time.time() + loss = criterion(out, target_master) + fwd_end = time.time() + logger.info('cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape), + fwd_end - fwd_start), + ranks=[0]) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + logger.info('Rank {} cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master))) + + bwd_start = time.time() + loss.backward() + bwd_end = time.time() + logger.info('cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + + loss_master.backward() + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j] + logger.info('Rank {} cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_vocab_parallel_loss(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + criterion = VocabParallelCrossEntropyLoss3D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=0)[i] + out = torch.chunk(out, DEPTH, dim=-1)[k] + out = torch.chunk(out, DEPTH, dim=0)[j] + out = out.clone() + out.requires_grad = True + + fwd_start = time.time() + loss = criterion(out, target_master) + fwd_end = time.time() + logger.info('vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start), + ranks=[0]) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + logger.info('Rank {} vocab parallel cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master))) + + bwd_start = time.time() + loss.backward() + bwd_end = time.time() + logger.info('vocab parallel cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + + loss_master.backward() + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] + out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[k] + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j] + logger.info('Rank {} vocab parallel cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start diff --git a/tests/test_layers/test_3d/checks_3d/common.py b/tests/test_layers/test_3d/checks_3d/common.py new file mode 100644 index 0000000000000000000000000000000000000000..afb19c4745cc72cd66a4d5a11239ed6f10f68d14 --- /dev/null +++ b/tests/test_layers/test_3d/checks_3d/common.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +DEPTH = 2 +BATCH_SIZE = 8 +SEQ_LENGTH = 8 +HIDDEN_SIZE = 8 +NUM_CLASSES = 8 +NUM_BLOCKS = 2 +IMG_SIZE = 16 +VOCAB_SIZE = 16 + + +def check_equal(A, B): + eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2) + assert eq, f"\nA = {A}\nB = {B}" + return eq \ No newline at end of file diff --git a/tests/test_layers/test_3d/test_3d.py b/tests/test_layers/test_3d/test_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..fde71a4a0d26c172b45c6d0ff48497d55174b893 --- /dev/null +++ b/tests/test_layers/test_3d/test_3d.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import pytest +import torch +from checks_3d.check_layer_3d import ( + check_classifier_no_given_weight, + check_embed, + check_layernorm, + check_linear, + check_loss, + check_patch_embed, + check_vocab_parallel_classifier_given_embed_weight, + check_vocab_parallel_classifier_no_given_weight, + check_vocab_parallel_embed, + check_vocab_parallel_loss, +) + +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn + +CONFIG = dict( + parallel=dict( + pipeline=1, + tensor=dict(mode='3d', size=8), + ), + seed=42, +) + + +def check_layer(): + check_linear() + check_layernorm() + check_classifier_no_given_weight() + check_vocab_parallel_classifier_no_given_weight() + check_vocab_parallel_classifier_given_embed_weight() + check_embed() + check_patch_embed() + check_vocab_parallel_embed() + check_loss() + check_vocab_parallel_loss() + + +def check_layer_and_operation(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + check_layer() + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_3d(): + spawn(check_layer_and_operation, 8) + + +if __name__ == '__main__': + test_3d() diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..22d4f02a48d726d8f1892f5cb99f57eca863490d --- /dev/null +++ b/tests/test_layers/test_cache_embedding.py @@ -0,0 +1,377 @@ +import random +from typing import List + +import numpy as np +import pytest +import torch + +import colossalai +from colossalai.nn.parallel.layers import ( + CachedEmbeddingBag, + CachedParamMgr, + EvictionStrategy, + ParallelCachedEmbeddingBag, + ParallelCachedEmbeddingBagTablewise, + TablewiseEmbeddingBagConfig, +) +from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + +NUM_EMBED, EMBED_DIM = 10, 8 +BATCH_SIZE = 8 + + +def set_seed(seed): + """ + To achieve reproducible results, it's necessary to fix random seeds + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def synthesize_1d_sparse_feature( + batch_size, + num_embed, + device, +): + indices_in_batch = batch_size * 2 + indices = torch.randint(low=0, high=num_embed, size=(indices_in_batch,), device=device, dtype=torch.long) + offsets = torch.from_numpy( + np.array([ + 0, *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))), indices_in_batch + ])).to(device).long() + return indices, offsets + + +@pytest.mark.skip +@clear_cache_before_run() +def test_cachemgr(): + model = torch.nn.EmbeddingBag(10000, 128) + # 10 chunks, 5 in cuda + mgr = CachedParamMgr(model.weight.detach(), 5) + assert mgr.cuda_row_num == 5 + + mgr._admit(1) + assert not mgr._chunk_in_cuda(2) + assert mgr._chunk_in_cuda(1) + + # print(mgr.cached_chunk_table) + mgr._admit(8) + + # now 3 chunk is available + assert mgr.cuda_available_chunk_num == 3 + + mgr._evict() + assert mgr.cuda_available_chunk_num == 4 + + mgr._prepare_rows_on_cuda(torch.tensor([9, 6, 5], dtype=torch.long, device=0)) + mgr._prepare_rows_on_cuda(torch.tensor([3, 4, 5], dtype=torch.long, device=0)) + # print(mgr.cached_chunk_table) + # mgr.print_comm_stats() + + mgr.flush() + assert mgr.cuda_available_chunk_num == 5 + + +@clear_cache_before_run() +def test_reorder_with_freq(): + num_embed = 100 + chunk_size = 1 + num_chunk = 5 + + idx_map = torch.randint(10000, size=(num_embed,)) + sorted_idx = torch.argsort(idx_map, descending=True).tolist() + chunkid, offset_in_chunk = [], [] + for i in range(num_embed): + idx = sorted_idx.index(i) + chunkid.append(idx // chunk_size) + offset_in_chunk.append(idx % chunk_size) + + dev = torch.device('cuda') + chunkid = torch.tensor(chunkid, dtype=torch.long, device=dev) + offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev) + + weight = torch.rand(num_embed, 2) + mgr = CachedParamMgr(weight, num_chunk) + + mgr.reorder(idx_map) + + indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev)) + mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor') + mgr_offsets = torch.remainder(indices, chunk_size) + assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}" + assert torch.allclose(offset_in_chunk, mgr_offsets), \ + f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}" + + +@clear_cache_before_run() +@parameterize('use_LFU', [True, False]) +def test_freq_aware_embed(use_LFU: bool): + device = torch.device('cuda', 0) + evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET + model = CachedEmbeddingBag(NUM_EMBED, + EMBED_DIM, + mode='mean', + include_last_offset=True, + cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0), + ids_freq_mapping=None, + evict_strategy=evict_strategy).to(device) + + assert model.weight.shape[0] == NUM_EMBED + ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device), + mode='mean', + include_last_offset=True, + freeze=False) + + assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device)) + + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3) + + for i in range(5): + indices, offsets = synthesize_1d_sparse_feature(BATCH_SIZE, NUM_EMBED, device) + res = model(indices, offsets) + ref_res = ref_model(indices, offsets) + assert torch.allclose(res, ref_res), f"model result: {res}, reference: {ref_res}" + + grad = torch.rand_like(res) + # comparing gradient here is nontrivial + res.backward(grad) + ref_res.backward(grad) + optimizer.step() + optimizer.zero_grad() + + ref_optimizer.step() + ref_optimizer.zero_grad() + + model.cache_weight_mgr.flush() + model_weight = model.weight.detach().to(device) + ref_weight = ref_model.weight.detach() + assert torch.allclose(model_weight, ref_weight), \ + f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}" + + +@clear_cache_before_run() +@parameterize('init_freq', [True, False]) +def test_lfu_strategy(init_freq: bool): + # minimal test to check behavior + Bag = CachedEmbeddingBag(5, + 5, + cache_ratio=3 / 5, + buffer_size=0, + pin_weight=True, + ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None, + warmup_ratio=1.0, + evict_strategy=EvictionStrategy.LFU) + + # print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map) + offsets = torch.tensor([0], device="cuda:0") + + # prepare frequency learning info: + Bag.forward(torch.tensor([2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) + + # check strategy + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([3], device="cuda:0"), offsets) # miss, evict 1 + Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit + Bag.forward(torch.tensor([4], device="cuda:0"), offsets) # miss, evict 3 + Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) # hit + + assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \ + "LFU strategy behavior failed" + + +def gather_tensor(tensor, rank, world_size): + gather_list = [] + if rank == 0: + gather_list = [torch.empty_like(tensor) for _ in range(world_size)] + + torch.distributed.gather(tensor, gather_list, dst=0) + return gather_list + + +def run_parallel_freq_aware_embed_tablewise(rank, world_size): + if world_size != 2: + return + device = torch.device('cuda', torch.cuda.current_device()) + + # initialize weight + # 3 feature tables. idx: 0~5, 6~10, 11~17 + weight_tables = torch.rand(18, 5) + weight_table1 = weight_tables[0:6] + weight_table2 = weight_tables[6:11] + weight_table3 = weight_tables[11:18] + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [] + embedding_bag_config_list.append( + TablewiseEmbeddingBagConfig(num_embeddings=6, + cuda_row_num=4, + assigned_rank=0, + initial_weight=weight_table1.clone().detach().cpu())) + embedding_bag_config_list.append( + TablewiseEmbeddingBagConfig(num_embeddings=5, + cuda_row_num=4, + assigned_rank=0, + initial_weight=weight_table2.clone().detach().cpu())) + embedding_bag_config_list.append( + TablewiseEmbeddingBagConfig(num_embeddings=7, + cuda_row_num=4, + assigned_rank=1, + initial_weight=weight_table3.clone().detach().cpu())) + if rank == 0: + _weight = torch.cat([weight_table1, weight_table2], 0) + else: + _weight = weight_table3 + model = ParallelCachedEmbeddingBagTablewise( + embedding_bag_config_list, + embedding_dim=5, + _weight=_weight, + include_last_offset=True, + cache_ratio=0.5, + buffer_size=0, + evict_strategy=EvictionStrategy.LFU, + ) + # explain + ''' + batch feature 1 feature 2 feature 3 + input0 [1,2,3] [6,7] [] + input1 [] [9] [13,15] + input2 [1,5] [6,8] [11] + ↑ ↑ ↑ + rank 0 rank 0 rank 1 + in KJT format + ''' + res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), + torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device), + already_split_along_rank=False) + optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) + rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device) + if rank == 0: + fake_grad = rand_grad[0:2] + else: + fake_grad = rand_grad[2:] + res.backward(fake_grad) + optimizer.step() + optimizer.zero_grad() + + # check correctness + if rank == 0: + ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(), + include_last_offset=True, + freeze=False).to(device) + ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2) + ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0) + ref_res = ref_model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), + torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device)) + ref_res.backward(ref_fake_grad) + ref_optimizer.step() + ref_optimizer.zero_grad() + + model.cache_weight_mgr.flush() + recover_weight = model.cache_weight_mgr.weight.to(device) + ref_weight = ref_model.weight.detach()[:11] + assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}" + + +def run_parallel_freq_aware_embed_columnwise(rank, world_size): + device = torch.device('cuda', torch.cuda.current_device()) + + num_embed = 100 + embed_dim = 16 + batch_size = 4 + + set_seed(4321) + weight = torch.rand(num_embed, embed_dim) + coloweight = ColoTensor(weight.clone().detach().cpu(), spec=None) + + # initialize the tensor spec for the embedding weight parameter, + # which is an ColoParameter. + coloweight.set_process_group(ProcessGroup(tp_degree=world_size)) + coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D)) + + model = ParallelCachedEmbeddingBag.from_pretrained( + coloweight, + include_last_offset=True, + freeze=False, + cache_ratio=batch_size * 2 / num_embed, + ) + + assert model.cache_weight_mgr.weight.device.type == 'cpu' + assert model.cache_weight_mgr.cuda_cached_weight.requires_grad + weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank] + print(f"model weight: {model.cache_weight_mgr.weight.shape}, ref weight: {weight_in_rank.shape}") + assert torch.allclose(weight_in_rank, + model.cache_weight_mgr.weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.weight}" + + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + + if rank == 0: + ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(), + include_last_offset=True, + freeze=False).to(device) + ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3) + + set_seed(4321) + for i in range(5): + indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device) + res = model(indices, offsets) + + grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device) + grad_in_rank = torch.tensor_split(grad, world_size, 0)[rank] + res.backward(grad_in_rank) + + optimizer.step() + optimizer.zero_grad() + + res_list = gather_tensor(res.detach(), rank, world_size) + + if rank == 0: + ref_res = ref_model(indices, offsets) + recover_res = torch.cat(res_list, dim=0) + + assert torch.allclose(ref_res, recover_res) + + ref_res.backward(grad) + ref_optimizer.step() + ref_optimizer.zero_grad() + + model.cache_weight_mgr.flush() + weight_list = gather_tensor(model.cache_weight_mgr.weight.detach().cuda(), rank, world_size) + if rank == 0: + recover_weight = torch.cat(weight_list, dim=1) + assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}" + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # run_parallel_freq_aware_embed_columnwise(rank, world_size) + run_parallel_freq_aware_embed_tablewise(rank, world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_parallel_freq_aware_embed(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + # test_freq_aware_embed(True) + test_parallel_freq_aware_embed(2) + # test_lfu_strategy(False) diff --git a/tests/test_layers/test_sequence/checks_seq/__init__.py b/tests/test_layers/test_sequence/checks_seq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py new file mode 100644 index 0000000000000000000000000000000000000000..2b7b999d43731ae5b5cd3f7cb87eecc6c1585fc4 --- /dev/null +++ b/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py @@ -0,0 +1,21 @@ +import torch + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn import TransformerSelfAttentionRing +from colossalai.utils import get_current_device + + +def check_selfattention(): + WORLD_SIZE = gpc.get_world_size(ParallelMode.SEQUENCE) + SUB_SEQ_LENGTH = 8 + BATCH = 4 + HIDDEN_SIZE = 16 + + layer = TransformerSelfAttentionRing(16, 8, 8, 0.1) + layer = layer.to(get_current_device()) + + hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device()) + attention_mask = torch.randint(low=0, high=2, + size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(get_current_device()) + out = layer(hidden_states, attention_mask) diff --git a/tests/test_layers/test_sequence/test_sequence.py b/tests/test_layers/test_sequence/test_sequence.py new file mode 100644 index 0000000000000000000000000000000000000000..aac192d7eff0b2df668a937c4eafdf9e11404f51 --- /dev/null +++ b/tests/test_layers/test_sequence/test_sequence.py @@ -0,0 +1,138 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence'))) + + +def check_ring_qk(rank, world_size): + # params + batch_size = 4 + num_heads = 4 + seq_length = 32 + attention_head_size = 32 + sub_seq_length = seq_length // world_size + + # create master tensors + q = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda() + k = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda() + dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + + # create distributed tensors + sub_q = q.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() + sub_k = k.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() + + # set autograd attributes + q.requires_grad = True + k.requires_grad = True + q.retain_grad() + k.retain_grad() + sub_q.requires_grad = True + sub_k.requires_grad = True + sub_q.retain_grad() + sub_k.retain_grad() + + # compute master attention scores + a = torch.matmul(q, k.transpose(2, 1)) + + # compute distributed attention scores + ring_qk = colossalai.nn.layer.parallel_sequence.RingQK.apply + sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length) + + # check master and distributed attetion scores + sub_master_a = a[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2) + + # run master backward + a.retain_grad() + a.mean().backward() + + # run distributed backward + partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + torch.autograd.backward(sub_a, partial_master_a_grad) + + # check master and distributed grads + partial_master_q_grad = q.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \ + 'attention score cannot match' + + +def check_ring_av(rank, world_size): + # params + batch_size = 4 + num_heads = 4 + seq_length = 16 + attention_head_size = 32 + sub_seq_length = seq_length // world_size + + # create master tensors + a = torch.rand(batch_size * num_heads, seq_length, seq_length).cuda() + v = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda() + dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + + # create distributed tensors + sub_a = a.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() + sub_v = v.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() + + # set autograd attributes + a.requires_grad = True + v.requires_grad = True + a.retain_grad() + v.retain_grad() + sub_a.requires_grad = True + sub_v.requires_grad = True + sub_a.retain_grad() + sub_v.retain_grad() + + # compute master attention scores + out = torch.matmul(a, v) + + # compute distributed attention scores + ring_av = colossalai.nn.layer.parallel_sequence.RingAV.apply + sub_out = ring_av(sub_a, sub_v, batch_size, num_heads, attention_head_size, sub_seq_length) + + # print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}') + + # check master and distributed output + sub_master_out = out[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2) + + # # run master backward + out.retain_grad() + out.mean().backward() + + # # run distributed backward + partial_master_out_grad = out.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + torch.autograd.backward(sub_out, partial_master_out_grad) + + # # check master and distributed grads + partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \ + 'attention output cannot match' + + +def run_test(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=port) + + # check_ring_qk(rank, world_size) + check_ring_av(rank, world_size) + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_sequence(): + spawn(run_test, 4) + + +if __name__ == '__main__': + test_sequence() diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..e7002a75f3f7b1d7dd5eff4fac0eba0d5de9f95d --- /dev/null +++ b/tests/test_moe/test_grad_handler.py @@ -0,0 +1,71 @@ +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn + +import colossalai +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.engine.gradient_handler import MoeGradientHandler +from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator +from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.utils.moe import sync_moe_model_param + +BATCH_SIZE = 4 +DIM = 16 +CONFIG = dict() + + +def run_test(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + expert_module = nn.Linear + expert_factor = dict(in_features=DIM, out_features=DIM, device=get_current_device()) + + MOE_CONTEXT.setup(42) # MOE initialization + noisy_func = UniformNoiseGenerator() + router = Top1Router(noisy_func=noisy_func) + num_experts_list = [1, 2, 4] + layer_list = [] + for num_experts in num_experts_list: + exp = Experts(expert_module, num_experts, **expert_factor) + moe_layer = MoeLayer(DIM, num_experts, router, exp) + layer_list.append(moe_layer) + + model = nn.ModuleList(layer_list) + model = model.to(get_current_device()) + sync_moe_model_param(model) + + dist_dict = MOE_CONTEXT.parallel_info_dict + assert_equal_in_group(layer_list[0].experts.experts[0].weight.data, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[1].experts.experts[0].weight.data, dist_dict[2].dp_group) + # MoE model synchronization passed + + grad_handler = MoeGradientHandler(model, 0) + + rank = dist.get_rank() + torch.cuda.manual_seed(78 + rank) + data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) + grad = torch.randn_like(data) + + MOE_CONTEXT.reset_loss() + for layer in layer_list: + data, _ = layer(data) + data.backward(grad) + grad_handler.handle_gradient() + + assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[0].experts.experts[0].bias.grad, dist_dict[1].dp_group) + + assert_equal_in_group(layer_list[1].experts.experts[0].weight.grad, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[1].experts.experts[0].bias.grad, dist_dict[2].dp_group) + # MoE grad handler test passed + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_grad_handler(): + spawn(run_test, 4) + + +if __name__ == '__main__': + test_grad_handler() diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..ad9a172b72aa479dd85b50483b79c4b7b22fa023 --- /dev/null +++ b/tests/test_moe/test_kernel.py @@ -0,0 +1,96 @@ +import pytest +import torch +import torch.nn as nn + +import colossalai +from colossalai.context import ParallelMode +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.core import global_context as gpc +from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, Top2Router +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + +BATCH_SIZE = 16 +NUM_EXPERTS = 4 +CONFIG = dict() + + +def check_equal(tensor_a, tensor_b, atol=1e-06): + assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol) is True + + +def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, router=Top2Router): + # Here we do not need TF32, since it brings absolute error on results + torch.backends.cuda.matmul.allow_tf32 = False + + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + local_rank = gpc.get_local_rank(ParallelMode.GLOBAL) + + MOE_CONTEXT.setup(42) # MOE environment initialization + MOE_CONTEXT.reset_loss() + torch.manual_seed(rs + local_rank) # set each process has different random seed + + # get randomized data + tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) + + expert_module = nn.Linear + expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device()) + expert = Experts(expert_module, NUM_EXPERTS, **expert_factor) + layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert) + layer = layer.to(get_current_device()) + if data_type == torch.float16: + layer = layer.half() + + # use matrix multiplication instead of COL_MOE_KERNL in MOE dispatch and combine + layer.use_kernel = False + old_out, _ = layer(tokens) + ech = old_out.shape + grad = torch.randn(ech, device=get_current_device()) + old_out.backward(grad) # get gradient + + # save all results + o_tk_grad = tokens.grad.data.clone() + o_gt_grad = layer.gate_weight.grad.data.clone() + + # reset all gradients + tokens.grad.zero_() + layer.gate_weight.grad.zero_() + + layer.use_kernel = True + new_out, _ = layer(tokens) # get ouputs through colossal kernel + + if data_type == torch.float32: + check_equal(old_out, new_out) + else: + check_equal(old_out, new_out, 1e-2) + # forward function passed + + new_out.backward(grad) # get new type gradient + n_tk_grad = tokens.grad.data.clone() + n_gt_grad = layer.gate_weight.grad.data.clone() + + if data_type == torch.float32: + check_equal(o_tk_grad, n_tk_grad) + else: + check_equal(o_tk_grad, o_tk_grad, 1e-2) + # tokens gradient is correct + + if data_type == torch.float32: + check_equal(o_gt_grad, n_gt_grad, 5e-05) + else: + check_equal(o_gt_grad, n_gt_grad, 2e-01) + # bias gradient is correct + + +@pytest.mark.dist +@pytest.mark.parametrize("rs", [131]) +@pytest.mark.parametrize("hidden_size", [32, 144]) +@pytest.mark.parametrize("data_type", [torch.float32, torch.float16]) +@pytest.mark.parametrize("router", [Top1Router, Top2Router]) +@rerun_if_address_is_in_use() +def test_moe_kernel(rs, hidden_size, data_type, router): + spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, router=router) + + +if __name__ == '__main__': + test_moe_kernel(2, 256, torch.float16, Top2Router) diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0283ba71fc9c6798c5800e4cc529acee47de99 --- /dev/null +++ b/tests/test_moe/test_moe_checkpoint.py @@ -0,0 +1,50 @@ +import os + +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.context import MOE_CONTEXT +from colossalai.nn.layer.moe import load_moe_model, save_moe_model +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext +from tests.test_moe.test_moe_zero_init import MoeModel +from tests.test_zero.test_legacy.common import CONFIG + + +def exam_moe_checkpoint(): + with ColoInitContext(device=get_current_device()): + model = MoeModel(checkpoint=True) + save_moe_model(model, 'temp_path.pth') + + with ColoInitContext(device=get_current_device()): + other_model = MoeModel(checkpoint=True) + load_moe_model(other_model, 'temp_path.pth') + + state_0 = model.state_dict() + state_1 = other_model.state_dict() + for k, v in state_0.items(): + u = state_1.get(k) + assert torch.equal(u.data, v.data) + + if dist.get_rank() == 0: + os.remove('temp_path.pth') + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + exam_moe_checkpoint() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2, 4]) +@rerun_if_address_is_in_use() +def test_moe_checkpoint(world_size): + spawn(_run_dist) + + +if __name__ == '__main__': + test_moe_checkpoint(world_size=4) diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py new file mode 100644 index 0000000000000000000000000000000000000000..555338fcf9fcde69bec2c19efdd26d97977ada66 --- /dev/null +++ b/tests/test_moe/test_moe_colo_init.py @@ -0,0 +1,56 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.context import MOE_CONTEXT +from colossalai.tensor import ColoParameter +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext +from tests.test_moe.test_moe_zero_init import MoeModel +from tests.test_tensor.common_utils import debug_print +from tests.test_zero.test_legacy.common import CONFIG + + +@parameterize("init_device_type", ['cpu', 'cuda']) +def exam_moe_colo_init(init_device_type): + world_size = dist.get_world_size() + + if init_device_type == 'cuda': + init_device = get_current_device() + elif init_device_type == 'cpu': + init_device = torch.device("cpu") + else: + raise NotImplementedError("Unknown device found.") + + with ColoInitContext(device=init_device): + model = MoeModel(checkpoint=True) + + for name, param in model.named_parameters(): + assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name) + + if hasattr(param, "moe_info"): + param.set_process_group(param.moe_info.pg) + + if hasattr(param, "moe_info"): + assert param.process_group.dp_world_size() == param.moe_info.dp_size + else: + assert param.process_group.dp_world_size() == world_size + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + exam_moe_colo_init() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_moe_colo_init(world_size): + spawn(_run_dist, world_size) + + +if __name__ == '__main__': + test_moe_colo_init(world_size=4) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py new file mode 100644 index 0000000000000000000000000000000000000000..6dc3f5f18b6df04764800814c8d02e70205347af --- /dev/null +++ b/tests/test_moe/test_moe_group.py @@ -0,0 +1,68 @@ +import pytest +import torch.distributed as dist +import torch.nn as nn + +import colossalai +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.nn.layer.moe import Experts +from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.utils.moe import sync_moe_model_param + +D_MODEL = 4 +D_FF = 8 +CONFIG = dict() + + +def run_test(rank, world_size, port): + world_size = 4 + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + expert_module = nn.Linear + expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device()) + + MOE_CONTEXT.setup(42) # MOE environment initialization + exp0 = Experts(expert_module, 1, **expert_factor) + exp1 = Experts(expert_module, 2, **expert_factor) + exp2 = Experts(expert_module, 4, **expert_factor) + exp3 = Experts(expert_module, 8, **expert_factor) + + assert exp0.num_local_experts == 1 + assert exp1.num_local_experts == 1 + assert exp2.num_local_experts == 1 + assert exp3.num_local_experts == 2 + # experts deployment passed + + parallel_info_dict = MOE_CONTEXT.parallel_info_dict + rank = dist.get_rank() + + assert len(parallel_info_dict) == 3 + assert dist.get_rank(parallel_info_dict[4].ep_group) == rank + assert dist.get_rank(parallel_info_dict[2].ep_group) == rank % 2 + assert dist.get_rank(parallel_info_dict[1].ep_group) == 0 + + assert dist.get_rank(parallel_info_dict[4].dp_group) == 0 + assert dist.get_rank(parallel_info_dict[2].dp_group) == rank // 2 + assert dist.get_rank(parallel_info_dict[1].dp_group) == rank + # group creation passed + + model = nn.ModuleList([exp0, exp1, exp2, exp3]) + model = model.to(get_current_device()) + sync_moe_model_param(model) + + assert_equal_in_group(exp0.experts[0].weight.data, parallel_info_dict[1].dp_group) + assert_equal_in_group(exp0.experts[0].bias.data, parallel_info_dict[1].dp_group) + # MOE experts layout success when ep_size = 1 + + assert_equal_in_group(exp1.experts[0].weight.data, parallel_info_dict[2].dp_group) + assert_equal_in_group(exp1.experts[0].bias.data, parallel_info_dict[2].dp_group) + # MOE experts layout success when ep_size = 2 + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_moe_initialization(): + spawn(run_test, 4) + + +if __name__ == '__main__': + test_moe_initialization() diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py new file mode 100644 index 0000000000000000000000000000000000000000..79722f9f40560bfa78ee2d871da94fdab2a185a3 --- /dev/null +++ b/tests/test_moe/test_moe_zero_init.py @@ -0,0 +1,108 @@ +import pytest +import torch +import torch.nn as nn + +import colossalai +from colossalai.context import MOE_CONTEXT +from colossalai.logging import get_dist_logger +from colossalai.nn import CheckpointModule +from colossalai.nn.layer import MoeModule +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from tests.test_zero.test_legacy.common import CONFIG + + +class MoeModel(nn.Module): + + def __init__(self, checkpoint: bool = False): + + class TestSubModule(CheckpointModule): + + def __init__(self): + super().__init__(checkpoint) + expert_cls = nn.Linear + expert_args_dict = dict(in_features=16, out_features=16) + self.moe = MoeModule(dim_model=16, + num_experts=8, + use_residual=True, + expert_cls=expert_cls, + **expert_args_dict) + self.proj = nn.Linear(16, 4) + + def _forward(self, x): + x, y = self.moe(x) + x = self.proj(x) + return x, y + + super().__init__() + self.test_embed = nn.Linear(4, 16) + self.test_transform = TestSubModule() + + def forward(self, x): + MOE_CONTEXT.reset_loss() + + x = self.test_embed(x) + x, y = self.test_transform(x) + + MOE_CONTEXT.add_loss(y) + return x + + +@parameterize("init_device_type", ['cpu', 'cuda']) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_moe_zero_init(init_device_type, shard_strategy_class): + logger = get_dist_logger("test_moe_zero_init") + + if init_device_type == 'cuda': + init_device = get_current_device() + elif init_device_type == 'cpu': + init_device = torch.device("cpu") + else: + raise NotImplementedError("Unknown device found.") + + model_numel_tensor = torch.zeros(1, dtype=torch.int) + with ZeroInitContext(target_device=init_device, + shard_strategy=shard_strategy_class(), + shard_param=True, + model_numel_tensor=model_numel_tensor): + model = MoeModel(checkpoint=True) + + for name, param in model.named_parameters(): + assert hasattr(param, 'colo_attr') + + # the parameters in moe experts and its gate should not be sharded + if ('experts' in name) or ('gate' in name) or ('residual_combine' in name): + assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name) + else: + assert param.colo_attr.sharded_data_tensor.is_sharded + + # the parameters in moe experts is not replicated + if 'experts' in name: + assert not param.colo_attr.is_replicated + else: + assert param.colo_attr.is_replicated + + if param.colo_attr.param_is_sharded: + assert param.colo_attr.data_payload.device.type == init_device.type, \ + f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' + else: + assert param.colo_attr.data_payload.device.type == 'cuda' + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + run_moe_zero_init() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2, 4]) +@rerun_if_address_is_in_use() +def test_moe_zero_init(world_size): + spawn(_run_dist, world_size) + + +if __name__ == '__main__': + test_moe_zero_init(world_size=2) diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ec37967f18c5a0b2fc735e38de062c4a1967e4f9 --- /dev/null +++ b/tests/test_moe/test_moe_zero_model.py @@ -0,0 +1,70 @@ +import pytest +import torch + +import colossalai +from colossalai.context import MOE_CONTEXT +from colossalai.engine.gradient_handler import MoeGradientHandler +from colossalai.nn import MoeLoss +from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_moe.test_moe_zero_init import MoeModel +from tests.test_zero.test_legacy.common import CONFIG, check_grads_padding, run_fwd_bwd + + +@parameterize("enable_autocast", [False]) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_model_test(enable_autocast, shard_strategy_class): + shard_strategy = shard_strategy_class() + + get_components_func = non_distributed_component_funcs.get_callable('hanging_param_model') + _, train_dataloader, _, optimizer_class, _ = get_components_func() + criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) + + with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = MoeModel(checkpoint=True) + zero_model = ShardedModelV2(zero_model, shard_strategy) + + # check whether parameters are identical in ddp + for name, p in zero_model.named_parameters(): + if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated: + assert_equal_in_group(p.colo_attr.data_payload) + + model = MoeModel(checkpoint=True).half() + col_model_deepcopy(zero_model, model) + model = model.cuda() + grad_handler = MoeGradientHandler(model) + + for i, (data, label) in enumerate(train_dataloader): + if i > 5: + break + + data, label = cast_tensor_to_fp16(data).cuda(), label.cuda() + run_fwd_bwd(model, data, label, criterion, enable_autocast) + run_fwd_bwd(zero_model, data, label, criterion, enable_autocast) + grad_handler.handle_gradient() + + check_grads_padding(model, zero_model, loose=True) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + run_model_test() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2]) +@rerun_if_address_is_in_use() +def test_moe_zero_model(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_moe_zero_model(world_size=2) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py new file mode 100644 index 0000000000000000000000000000000000000000..efc6e9ddae27311ab5e16771b366b2baf87d2e00 --- /dev/null +++ b/tests/test_moe/test_moe_zero_optim.py @@ -0,0 +1,120 @@ +import pytest +import torch + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.context import MOE_CONTEXT +from colossalai.engine.gradient_handler import MoeGradientHandler +from colossalai.nn import MoeLoss +from colossalai.nn.optimizer import CPUAdam +from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy +from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 +from colossalai.zero.low_level._utils import has_inf_or_nan +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_moe.test_moe_zero_init import MoeModel +from tests.test_zero.test_legacy.common import CONFIG, check_sharded_model_params + + +def _run_step(model, optimizer, data, label, criterion, grad_handler): + model.train() + optimizer.zero_grad() + + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + + loss = loss.float() + if isinstance(model, ShardedModelV2): + optimizer.backward(loss) + else: + loss.backward() + + if grad_handler is not None: + grad_handler.handle_gradient() + + optimizer.step() + + +@parameterize("cpu_offload", [True]) +@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug +@parameterize("reuse_fp16_shard", [True, False]) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def _run_test_sharded_optim_v2(cpu_offload, + shard_strategy_class, + use_cpuadam, + reuse_fp16_shard, + gpu_margin_mem_ratio=0.0): + shard_strategy = shard_strategy_class() + if use_cpuadam and cpu_offload is False: + return + MOE_CONTEXT.reset_loss() + get_components_func = non_distributed_component_funcs.get_callable('hanging_param_model') + _, train_dataloader, _, optimizer_class, _ = get_components_func() + criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) + + with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = MoeModel(checkpoint=True) + + zero_model = ShardedModelV2(zero_model, + shard_strategy, + tensor_placement_policy='cpu' if cpu_offload else 'cuda', + reuse_fp16_shard=reuse_fp16_shard) + + # check whether parameters are identical in ddp + for name, p in zero_model.named_parameters(): + if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated: + assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device())) + + model = MoeModel(checkpoint=True).half() + col_model_deepcopy(zero_model, model) + model = model.cuda().float() + + if use_cpuadam: + optimizer_class = CPUAdam + optim = optimizer_class(model.parameters(), lr=1e-3) + sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) + sharded_optim = ShardedOptimizerV2(zero_model, + sharded_optim, + initial_scale=2**5, + gpu_margin_mem_ratio=gpu_margin_mem_ratio) + + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False) + apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) + apex_grad_handler = MoeGradientHandler(model) + + for i, (data, label) in enumerate(train_dataloader): + if i > 5: + break + data, label = data.cuda(), label.cuda() + _run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler) + _run_step(zero_model, sharded_optim, data, label, criterion, None) + check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam) + for param in model.parameters(): + assert not has_inf_or_nan(param) + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + _run_test_sharded_optim_v2() + + +# use_cpuadam = True can be used with cpu_offload = False +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2]) +@rerun_if_address_is_in_use() +def test_moe_zero_optim(world_size): + spawn(_run_dist, world_size) + + +if __name__ == '__main__': + test_moe_zero_optim(world_size=4) diff --git a/tests/test_ops/test_addmm_tp.py b/tests/test_ops/test_addmm_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..ecd3721b902e252de6c765a6f8b1c96215d984a9 --- /dev/null +++ b/tests/test_ops/test_addmm_tp.py @@ -0,0 +1,73 @@ +import pytest +import torch +import torch.nn as nn + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal + + +class Conv1D(nn.Module): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + Basically works like a linear layer but the weights are transposed. + Args: + nf (`int`): The number of output features. + nx (`int`): The number of input features. + """ + + def __init__(self, nf, nx): + super().__init__() + self.nf = nf + w = torch.empty(nx, nf) + nn.init.normal_(w, std=0.02) + self.weight = nn.Parameter(w) + self.bias = nn.Parameter(torch.ones(nf)) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) + return x + + +def run_with_spec(spec_init_func, split_bias): + model = Conv1D(4, 16).cuda() + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + + weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg)) + bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg)) + + spec_init_func(weight, pg) + if split_bias: + spec_init_func(bias, pg) + + x = torch.rand(2, 16).cuda() + out = model(x) + colo_out = torch.addmm(bias, x, weight) + colo_out = colo_out.to_replicate() + assert tensor_equal(out, colo_out) + grad = torch.rand_like(out) + out.backward(grad) + colo_out.backward(grad) + tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size()) + tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size()) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=False) + run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=True) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_addmm_1d(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_addmm_1d(4) diff --git a/tests/test_ops/test_embedding_bag_tp.py b/tests/test_ops/test_embedding_bag_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..d3d3dcf7e2c9a2737c26e97ad4e7ae487c201917 --- /dev/null +++ b/tests/test_ops/test_embedding_bag_tp.py @@ -0,0 +1,43 @@ +import pytest +import torch +from torch.nn import functional as F + +import colossalai +from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_tensor.common_utils import split_param_col_tp1d, tensor_equal, tensor_shard_equal + + +def run_with_spec(spec_init_func): + pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) + model = torch.nn.EmbeddingBag(10, 4).cuda() + weight = ColoParameter(model.weight.clone(), True, ColoTensorSpec(pg)) + + spec_init_func(weight, pg) + + inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda() + offsets = torch.tensor([0, 4]).cuda() + out = model(inputs, offsets=offsets) + colo_out = F.embedding_bag(inputs, weight, offsets=offsets) + assert tensor_equal(out, colo_out) + grad = torch.rand_like(out) + out.backward(grad) + colo_out.backward(grad) + assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size()) + + +def run_dist(rank, world_size, port): + config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_with_spec(split_param_col_tp1d) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_embedding_bag_1d(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_embedding_bag_1d(4) diff --git a/tests/test_ops/test_embedding_tp.py b/tests/test_ops/test_embedding_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..c0b376e2c92a298bf2e8c42257ba8b39a14e9a35 --- /dev/null +++ b/tests/test_ops/test_embedding_tp.py @@ -0,0 +1,44 @@ +import pytest +import torch +from torch.nn import functional as F + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal + + +def run_with_spec(spec_init_func, pg: ProcessGroup): + model = torch.nn.Embedding(12, 32).cuda() + weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg)) + + spec_init_func(weight, pg) + + x = torch.tensor((0, 3, 6, 9)).cuda() + out = model(x) + colo_out = F.embedding(x, weight) + assert tensor_equal(out, colo_out) + grad = torch.rand_like(out) + out.backward(grad) + colo_out.backward(grad) + # compare grad inside a TP group + assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size()) + + +def run_dist(rank, world_size, port): + # config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(tp_degree=world_size) + run_with_spec(split_param_row_tp1d, pg) + run_with_spec(split_param_col_tp1d, pg) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_embedding_1d(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_embedding_1d(4) diff --git a/tests/test_ops/test_linear_tp.py b/tests/test_ops/test_linear_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..c88adfdd9a7757cd7eabc8f27d40adae6374c61c --- /dev/null +++ b/tests/test_ops/test_linear_tp.py @@ -0,0 +1,48 @@ +import pytest +import torch +import torch.nn.functional as F + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal + + +def run_with_spec(spec_init_func, split_bias): + pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) + model = torch.nn.Linear(4, 8).cuda() + weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg)) + bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg)) + + spec_init_func(weight, pg) + if split_bias: + spec_init_func(bias, pg) + + x = torch.rand(2, 4).cuda() + out = model(x) + colo_out = F.linear(x, weight, bias) + colo_out = colo_out.to_replicate() + assert tensor_equal(out, colo_out) + grad = torch.rand_like(out) + out.backward(grad) + colo_out.backward(grad) + assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size()) + assert tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size()) + + +def run_dist(rank, world_size, port): + config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=False) + run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=True) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_linear_1d(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_linear_1d(4) diff --git a/tests/test_ops/test_loss_func.py b/tests/test_ops/test_loss_func.py new file mode 100644 index 0000000000000000000000000000000000000000..fc55c7f7725412ec75411e60b51791c654848ca3 --- /dev/null +++ b/tests/test_ops/test_loss_func.py @@ -0,0 +1,48 @@ +import pytest +import torch +import torch.nn.functional as F + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + + +def check_cross_entropy(): + input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True) + input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True) + with torch.no_grad(): + input_ct.copy_(input_t) + + target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device()) + + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg)) + input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()])) + input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D)) + + output = F.cross_entropy(input_t, target) + output_colo = F.cross_entropy(input_shard, target) + assert torch.allclose(output_colo, output) + + output.backward() + output_colo.backward() + + assert torch.allclose(input_t.grad, input_ct.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_cross_entropy() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_loss_func(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_loss_func(1) diff --git a/tests/test_ops/test_op.py b/tests/test_ops/test_op.py new file mode 100644 index 0000000000000000000000000000000000000000..4176d3b64d90e4c560ba48a6de42a9c9e26b73d1 --- /dev/null +++ b/tests/test_ops/test_op.py @@ -0,0 +1,87 @@ +import pytest +import torch +import torch.nn.functional as F +from torch.nn import Parameter + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + + +def _run_layer_norm(): + ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device()) + + input_t = torch.randn(3, 2, device=get_current_device()) + + pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) + input_t_colo = ColoTensor.from_torch_tensor(input_t.clone().detach(), ColoTensorSpec(pg)) + + # prepare colossalai LN + weight = ColoTensor(Parameter(ln_op.weight.detach()), ColoTensorSpec(pg)) + bias = ColoTensor(Parameter(ln_op.bias.detach()), ColoTensorSpec(pg)) + + output = ln_op(input_t) + output_colo = F.layer_norm(input_t_colo, ln_op.normalized_shape, weight, bias, ln_op.eps) + + assert torch.allclose(output_colo, output) + + torch.mean(output).backward() + torch.mean(output_colo).backward() + + assert torch.allclose(ln_op.weight.grad, weight.grad) + + +def check_spec_eq(tensor, other): + assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor) + for k in dir(tensor.dist_spec): + if not k.startswith('__'): + assert hasattr(other.dist_spec, k), f"{k}" + assert getattr(tensor.dist_spec, k) == getattr(other.dist_spec, k) + + +def check_element_wise_ops(): + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + t = torch.rand(2, 2) + x = ColoTensor(t, spec=ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()]))) + + check_spec_eq(x, x.cuda()) + assert torch.equal(x.cuda(), t.cuda()) + check_spec_eq(x, torch.abs(x)) + assert torch.equal(torch.abs(x), torch.abs(t)) + check_spec_eq(x, F.sigmoid(x)) + assert torch.equal(F.sigmoid(x), F.sigmoid(t)) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_element_wise_ops() + _run_layer_norm() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_element_wise_ops(world_size): + spawn(run_dist, world_size) + + +def run_dist2(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_layer_norm() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1]) +@rerun_if_address_is_in_use() +def test_ln(world_size): + spawn(run_dist2, world_size) + + +def check_all(): + test_element_wise_ops(2) + + +if __name__ == '__main__': + check_all() diff --git a/tests/test_ops/test_view.py b/tests/test_ops/test_view.py new file mode 100644 index 0000000000000000000000000000000000000000..a9f2033201c7ac6a1797b92dd6e834d97a90b5d1 --- /dev/null +++ b/tests/test_ops/test_view.py @@ -0,0 +1,97 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec +from colossalai.tensor.distspec import DistPlacementPattern +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from tests.test_tensor.common_utils import debug_print, split_param_col_tp1d, split_param_row_tp1d + + +def exam_view_core(pg): + # the case of replicated ColoTensors + x = torch.randn(4, 4).cuda() + x_colo = ColoTensor(x, ColoTensorSpec(pg)) + + y = x.view(2, -1, 2) + y_colo = x_colo.view(2, -1, 2) + + assert torch.all(y == y_colo) + assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE + # the perfect case of col-sliced ColoTensors + split_param_col_tp1d(x_colo, pg) + + z = x.view(torch.Size((2, 1, 2, -1))) + z_colo = x_colo.view(torch.Size((2, 1, 2, -1))) + if dist.get_rank() == 0: + z = z[:, :, :, 0:2] + else: + z = z[:, :, :, 2:] + assert torch.all(z == z_colo) + assert z_colo.dist_spec == x_colo.dist_spec + # the perfect case of row-sliced ColoTensors + split_param_row_tp1d(x_colo, pg) + + z = x.view(torch.Size((-1, 2, 2))) + z_colo = x_colo.view(torch.Size((-1, 2, 2))) + if dist.get_rank() == 0: + z = z[0:2, :, :] + else: + z = z[2:, :, :] + assert torch.all(z == z_colo) + assert z_colo.dist_spec == x_colo.dist_spec + # the normal case of row-sliced ColoTensors + z = x.view(-1, 2, 2, 2) + z_colo = x_colo.view(-1, 2, 2, 2) + assert torch.all(z == z_colo) + assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE + + +def exam_view_autograd(pg): + x = torch.randn(8, 2, device=get_current_device(), requires_grad=True) + y = torch.randn(8, 2, device=get_current_device(), requires_grad=True) + with torch.no_grad(): + y.copy_(x) + y = ColoTensor(y, ColoTensorSpec(pg)) + y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()])) + + xx = x.view(2, 2, -1) + yy_slice = y_slice.view(2, 2, -1) + yy = yy_slice.to_replicate() + grad = torch.randn(2, 2, 4, device=get_current_device()) + + xx.backward(grad) + yy.backward(grad) + assert torch.all(x.grad == y.grad) + + +def exam_view_errors(pg): + x = torch.randn(8, 2, device=get_current_device()) + x = ColoTensor(x, ColoTensorSpec(pg)) + split_param_row_tp1d(x, pg) + + x.view('a', 'b', 'c') + x.view(8, -1) + x.view([-2, -2, -2]) + x.view((-1, -1, -1)) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) + exam_view_core(pg) + exam_view_autograd(pg) + # exam_view_errors(pg) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_view(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_view(2) diff --git a/tests/test_optimizer/test_cpu_adam.py b/tests/test_optimizer/test_cpu_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..8b3ecf8517f7cd13cc67555bb81cf388efdd0654 --- /dev/null +++ b/tests/test_optimizer/test_cpu_adam.py @@ -0,0 +1,121 @@ +import math + +import torch + +from colossalai.testing import clear_cache_before_run, parameterize + + +def torch_adam_update( + step, + lr, + beta1, + beta2, + eps, + weight_decay, + param, + grad, + exp_avg, + exp_avg_sq, + use_adamw, +): + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + if weight_decay != 0: + if use_adamw: + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + step_size = lr / bias_correction1 + + param.addcdiv_(exp_avg, denom, value=-step_size) + + +def assertLess(data_diff, threshold, msg): + assert data_diff < threshold, msg + + +def assertTrue(condition, msg): + assert condition, msg + + +@clear_cache_before_run() +@parameterize('adamw', [True, False]) +@parameterize('step', [1, 2]) +@parameterize('p_dtype', [torch.float, torch.half]) +@parameterize('g_dtype', [torch.float, torch.half]) +def test_cpu_adam(adamw, step, p_dtype, g_dtype): + lr = 1e-3 + beta1, beta2 = 0.9, 0.999 + eps = 1e-8 + weight_decay = 0 + + for i in range(3): + p_data = torch.rand(64, dtype=p_dtype) + p_data_copy = p_data.clone().float() + p_grad = torch.rand(64, dtype=g_dtype) + p_grad_copy = p_grad.clone().float() + exp_avg = torch.rand(p_data.shape) + exp_avg_copy = exp_avg.clone() + exp_avg_sq = torch.rand(p_data.shape) + exp_avg_sq_copy = exp_avg_sq.clone() + + from colossalai.kernel.op_builder import CPUAdamBuilder + cpu_optim = CPUAdamBuilder().load() + + cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw) + + cpu_adam_op.step( + step, + lr, + beta1, + beta2, + eps, + weight_decay, + True, + p_data.view(-1), # fp32 data + p_grad.view(-1), # fp32 grad + exp_avg.view(-1), + exp_avg_sq.view(-1), + -1, + ) + + torch_adam_update( + step, + lr, + beta1, + beta2, + eps, + weight_decay, + p_data_copy, # fp32 data + p_grad_copy, # fp32 grad + exp_avg_copy, + exp_avg_sq_copy, + adamw, + ) + var = p_data_copy - p_data + data_diff = torch.max(torch.abs(var)) + threshold = 1e-3 + assertLess( + data_diff, + threshold, + f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, eps " + f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}", + ) + max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad)) + assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}") + max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg)) + assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}") + max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq)) + assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}") + + +if __name__ == '__main__': + test_cpu_adam() diff --git a/tests/test_optimizer/test_fused_adam.py b/tests/test_optimizer/test_fused_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..114d5293dad96063d9cdd7103be01a73fe60524a --- /dev/null +++ b/tests/test_optimizer/test_fused_adam.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn +from torch.optim import AdamW +from torch.optim.adam import Adam + +from colossalai.nn.optimizer.fused_adam import FusedAdam +from colossalai.testing import clear_cache_before_run, parameterize + + +class FC(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Sequential(nn.Linear(64, 64)) + + def forward(self, x): + return self.fc(x) + + +@clear_cache_before_run() +@parameterize('adamw', [False, True]) +@parameterize('p_dtype', [torch.float, torch.half]) +@parameterize('g_dtype', [torch.float, torch.half]) +def test_adam(adamw, p_dtype, g_dtype): + model = FC().cuda().to(p_dtype) + state = model.state_dict() + model_copy = FC().cuda().to(p_dtype) + model_copy.load_state_dict(state.copy()) + + if adamw: + optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True) + torch_optim = AdamW(model_copy.parameters(), lr=1e-3) + else: + optim = FusedAdam(model.parameters(), lr=1e-3) + torch_optim = Adam(model_copy.parameters(), lr=1e-3) + + data = torch.rand(1024, 64).cuda().to(p_dtype) + data_copy = data.clone() + label = torch.rand(1024, 64).cuda().to(p_dtype) + + for d, l in zip(data, label): + y = model(d) + loss = ((l - y)**2).sum() + optim.zero_grad() + loss.backward() + if p_dtype != g_dtype: + for i in range(len(optim.param_groups[0]['params'])): + optim.param_groups[0]['params'][i].grad.data = optim.param_groups[0]['params'][i].grad.data.to(g_dtype) + optim.step() + + for d, l in zip(data_copy, label): + y = model_copy(d) + loss = ((l - y)**2).sum() + torch_optim.zero_grad() + loss.backward() + torch_optim.step() + + assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params']) + + for i in range(len(optim.param_groups[0]['params'])): + if torch.isnan(optim.param_groups[0]['params'][i]).any() \ + or torch.isnan(torch_optim.param_groups[0]['params'][i]).any(): + continue + assert torch.allclose(optim.param_groups[0]['params'][i], torch_optim.param_groups[0]['params'][i], 2e-3, 2e-3) diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..4afa13349c1be4a1424855886d95e17a14047369 --- /dev/null +++ b/tests/test_optimizer/test_fused_adam_kernel.py @@ -0,0 +1,95 @@ +import math + +import torch +import torch.nn as nn +from numpy import dtype + +from colossalai.testing import clear_cache_before_run, parameterize +from colossalai.utils import multi_tensor_applier + + +def torch_adam_update( + step, + lr, + beta1, + beta2, + eps, + weight_decay, + param, + grad, + exp_avg, + exp_avg_sq, + use_adamw, +): + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + if weight_decay != 0: + if use_adamw: + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + step_size = lr / bias_correction1 + + param.addcdiv_(exp_avg, denom, value=-step_size) + + +@clear_cache_before_run() +@parameterize('adamw', [False, True]) +@parameterize('step', [1, 2]) +@parameterize('p_dtype', [torch.float, torch.half]) +@parameterize('g_dtype', [torch.float, torch.half]) +def test_adam(adamw, step, p_dtype, g_dtype): + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() + fused_adam = fused_optim.multi_tensor_adam + + dummy_overflow_buf = torch.cuda.IntTensor([0]) + + count = 0 + + for i in range(3): + p = torch.rand(64, dtype=p_dtype).cuda() + p_copy = p.clone().float() + g = torch.rand(p.shape, dtype=g_dtype).cuda() + g_copy = g.clone().float() + m = torch.rand(p.shape).cuda() + m_copy = m.clone() + v = torch.rand(p.shape).cuda() + v_copy = v.clone() + + lr = 1e-3 + beta1, beta2 = 0.9, 0.999 + eps = 1e-8 + weight_decay = 0 + + multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw, + True, weight_decay, -1) + + torch_adam_update( + step, + lr, + beta1, + beta2, + eps, + weight_decay, + p_copy, # fp32 data + g_copy, # fp32 grad + m_copy, + v_copy, + adamw, + ) + + if torch.isnan(p).any() or torch.isnan(p_copy).any(): + count += 1 + continue + assert count < 200, "too many nans" + assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5, + 1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}" diff --git a/tests/test_optimizer/test_hybrid_adam.py b/tests/test_optimizer/test_hybrid_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..d075149dfcb1e6a7c3d6357f1987a68750004810 --- /dev/null +++ b/tests/test_optimizer/test_hybrid_adam.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +from torch.optim import AdamW +from torch.optim.adam import Adam + +from colossalai.nn.optimizer.hybrid_adam import HybridAdam +from colossalai.testing import clear_cache_before_run, parameterize + +RE = 3 + + +@clear_cache_before_run() +@parameterize('adamw', [False, True]) +@parameterize('device', ['cpu', 'cuda:0']) +@parameterize('p_dtype', [torch.float]) +@parameterize('g_dtype', [torch.float, torch.half]) +def test_adam(adamw, device, p_dtype, g_dtype): + rng_state = torch.get_rng_state() + p = nn.Parameter(torch.rand(64).to(device, p_dtype)) + torch.set_rng_state(rng_state) + p_copy = nn.Parameter(torch.rand(64).to(device).float()) + + if adamw: + optim = HybridAdam([p], lr=1e-3, adamw_mode=True) + torch_optim = AdamW([p_copy], lr=1e-3) + else: + optim = HybridAdam([p], lr=1e-3) + torch_optim = Adam([p_copy], lr=1e-3) + + print(f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}") + for i in range(RE): + p.grad = torch.rand(64).to(device, p_dtype) + p_copy.grad = p.grad.clone().float() + p.grad.data = p.grad.data.to(g_dtype) + + optim.step() + torch_optim.step() + + if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any(): + continue + assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \ + f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}" diff --git a/tests/test_optimizer/test_nvme.py b/tests/test_optimizer/test_nvme.py new file mode 100644 index 0000000000000000000000000000000000000000..5d794ac2dd1a92123fd8fde23c40fcc529570aa7 --- /dev/null +++ b/tests/test_optimizer/test_nvme.py @@ -0,0 +1,49 @@ +import pytest +import torch + +from colossalai.nn.optimizer import CPUAdam, HybridAdam +from colossalai.testing import clear_cache_before_run, parameterize +from tests.components_to_test.registry import non_distributed_component_funcs + + +def move_some_params_to_cuda(model, torch_model): + model.embed.weight.data = model.embed.weight.cuda() + torch_model.embed.weight.data = model.embed.weight.cuda() + model.ln1.weight.data = model.ln1.weight.cuda() + torch_model.ln1.weight.data = model.ln1.weight.cuda() + + +def check_params_equal(model, torch_model): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + assert torch.allclose(p, torch_p, atol=1e-3), f'diff: {torch.abs(p - torch_p)}' + + +@clear_cache_before_run() +@parameterize('nvme_offload_fraction', [0.0, 0.5, 1.0]) +@parameterize('nvme_offload_dir', ['./offload', None]) +@parameterize('adam_cls', [CPUAdam, HybridAdam]) +def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls): + get_components_func = non_distributed_component_funcs.get_callable('simple_net') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model = model_builder() + torch_model = model_builder() + move_some_params_to_cuda(model, torch_model) + optimizer = adam_cls(model.parameters(), + lr=0.1, + nvme_offload_fraction=nvme_offload_fraction, + nvme_offload_dir=nvme_offload_dir) + torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.1) + with torch.no_grad(): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + torch_p.copy_(p) + p.grad = torch.rand_like(p) + torch_p.grad = p.grad + + for _ in range(3): + optimizer.step() + torch_optimizer.step() + check_params_equal(model, torch_model) + + +if __name__ == '__main__': + test_nvme_adam(0.5, './offload', CPUAdam) diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dab474a4ee21156d4ebdcd6e19570878776b1efe --- /dev/null +++ b/tests/test_pipeline/rpc_test_utils.py @@ -0,0 +1,150 @@ +import argparse +import os +import warnings + +import torch +import torch.distributed as dist +import torch.distributed.rpc as rpc +import torch.multiprocessing as mp +from torch import nn +from torch._C._distributed_rpc import _is_current_rpc_agent_set +from torch.optim import SGD, Adam, Optimizer, RMSprop + +from colossalai import launch +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.pipeline_process_group import ppg + +rpc_is_initialized = _is_current_rpc_agent_set + + +def color_debug(text, prefix=' ', color='blue'): + color = color.upper() + print(getattr(Back, color), prefix, Style.RESET_ALL, text) + + +class MLP(nn.Module): + + def __init__(self, dim: int, layers: int): + super().__init__() + self.layers = torch.nn.ModuleList() + + for _ in range(layers): + self.layers.append(nn.Linear(dim, dim, bias=False)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x.sum() + + +class DAG_MLP(nn.Module): + + def __init__(self, dim: int, layers: int): + super().__init__() + self.layers = torch.nn.ModuleList() + self.dag_layer = nn.Linear(dim, dim, bias=False) + + for _ in range(layers): + self.layers.append(nn.Linear(dim, dim, bias=False)) + + def forward(self, x, y): + for layer in self.layers: + x = layer(x) + y = self.dag_layer(y) + return x.sum(), y.sum() + + +class RpcTestModel(nn.Module): + + def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None: + super().__init__() + self.rank = stage_id + self.is_last_rank = stage_id == actual_stage_num - 1 + self.linear_name = f'linear_{stage_id}' + + if stage_id == 0: + linear = nn.Linear(feat_num, h) + elif stage_id == actual_stage_num - 1: + linear = nn.Linear(h, 1) + else: + linear = nn.Linear(h, h) + + setattr(self, self.linear_name, linear) + + def forward(self, x) -> torch.Tensor: + linear: nn.Module = getattr(self, self.linear_name) + out: torch.Tensor = linear(x) + + if self.is_last_rank: + out = out.sum() + return out + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--epoch', type=int, default=1) + parser.add_argument('--world_size', type=int, default=2) + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--dp_degree', type=int, default=1) + parser.add_argument('--tp_degree', type=int, default=1) + parser.add_argument('--num_microbatches', type=int, default=2) + parser.add_argument('--chunk', type=int, default=1) + parser.add_argument('--use_checkpoint', action='store_true') + parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD') + parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') + parser.add_argument('--master_addr', type=str, default='localhost') + parser.add_argument('--master_port', type=str, default='29020') + parser.add_argument('--num_worker_threads', type=str, default=128) + return parser.parse_args() + + +def pg_parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=4) + parser.add_argument('--dp_degree', type=int, default=2) + parser.add_argument('--tp_degree', type=int, default=1) + parser.add_argument('--chunk', type=int, default=1) + parser.add_argument('--num_worker_threads', type=str, default=128) + parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') + parser.add_argument('--master_addr', type=str, default='localhost') + parser.add_argument('--master_port', type=str, default='29020') + return parser.parse_args() + + +def run_worker(rank, args, master_func): + os.environ['MASTER_ADDR'] = args.master_addr + os.environ['MASTER_PORT'] = args.master_port + + device = args.device + world_size = args.world_size + dp_degree = args.dp_degree + tp_degree = args.tp_degree + num_worker_threads = args.num_worker_threads + host = args.master_addr + port = args.master_port + backend = 'nccl' if device == 'cuda' else 'gloo' + + disable_existing_loggers() + + launch(dict(), rank, world_size, host, int(port), backend, verbose=False) + ppg.set_global_info(rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device) + + # in rpc mode, only rank 0 is needed to be coded + if rank == 0: + master_func(args) + # barrier here + if rpc_is_initialized(): + rpc.shutdown() + else: + warnings.warn("RPC has not been initialized") + + +def rpc_run(args, master_func): + world_size = args.world_size + assert args.num_microbatches >= args.world_size, "num_microbatches cannot be fewer than world_size!" + mp.spawn(run_worker, args=(args, master_func), nprocs=world_size) diff --git a/tests/test_pipeline/test_cuda_rpc_chimera.py b/tests/test_pipeline/test_cuda_rpc_chimera.py new file mode 100644 index 0000000000000000000000000000000000000000..45ad8f828e61506649295b88ecd7e2fbfd5dcd3c --- /dev/null +++ b/tests/test_pipeline/test_cuda_rpc_chimera.py @@ -0,0 +1,80 @@ +import torch +from torch import nn +import torch.autograd as autograd + +from colossalai.pipeline.rpc import ChimeraPipelineEngine +from colossalai.testing import assert_close +from rpc_test_utils import rpc_run, parse_args, RpcTestModel + +# global variable for model created +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + + +def run_master(args): + torch.manual_seed(100) + + epoch = args.epoch + device = args.device + stage_num = args.world_size + chunk = 1 + num_microbatches = args.num_microbatches + use_checkpoint = False + + sample_num = 1024 + batch_size = 1024 + + assert sample_num % batch_size == 0 + + engine = ChimeraPipelineEngine(partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + checkpoint=use_checkpoint) + engine.initialize_optimizer(torch.optim.Adam, lr=1e-3) + + input_sample = torch.randn((sample_num, feat_num), device=device) + + forward_result = engine.forward_backward(input_sample) + + cuda_rpc_result = [] + single_result = [] + actual_stage_num = engine._get_actual_stage_num() + + # compute forward result and backward grad of parameters in cuda rpc + cuda_rpc_result.append(sum(forward_result[0])) + grad = engine.remote_grad() + for stage_id in range(actual_stage_num): + for p in grad[stage_id]: + cuda_rpc_result.append(p) + + # compute forward result and backward grad of parameters just in rank_0 + test_model = nn.Sequential( + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) + # input_sample = input_sample[len(input_sample) // 2:] + input_sample = input_sample.requires_grad_() + out_val = test_model(input_sample).sum() + autograd.backward(out_val) + single_result.append(out_val) + for p in test_model.parameters(): + single_result.append(p.grad) + + # print("my") + # print(cuda_rpc_result[1]) + # print("answer:") + # print(single_result[1]) + + # assert len(cuda_rpc_result) == len(single_result) + # for r_c, r_s in zip(cuda_rpc_result, single_result): + # assert_close(r_c, r_s, 0.001, 0.001) + + +if __name__ == "__main__": + args = parse_args() + rpc_run(args, run_master) diff --git a/tests/test_pipeline/test_cuda_rpc_optimizer.py b/tests/test_pipeline/test_cuda_rpc_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..842566730caf2454bdeb6e27b864e58beb3ba7a1 --- /dev/null +++ b/tests/test_pipeline/test_cuda_rpc_optimizer.py @@ -0,0 +1,81 @@ +import torch +from torch import nn +from torch import autograd +from torch.optim import SGD, Adam, RMSprop, Optimizer + +from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.testing import assert_close +from rpc_test_utils import rpc_run, parse_args, RpcTestModel + +# global variable for model created +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + + +def run_master(args): + torch.manual_seed(100) + + device = args.device + stage_num = args.world_size + chunk = args.chunk + actual_stage_num = stage_num * chunk + use_checkpoint = args.use_checkpoint + num_microbatches = args.num_microbatches + optimizer_class = globals()[args.optimizer] + + lr = 1e-3 + sample_num = 1024 + batch_size = 1024 + + assert sample_num % batch_size == 0 + + input_sample = torch.randn((sample_num, feat_num), device=device) + + engine = OneFOneBPipelineEngine(partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint) + + engine.initialize_optimizer(optimizer_class, lr=lr) + + _ = engine.forward_backward(input_sample) + + cuda_rpc_result = [] + single_result = [] + actual_stage_num = engine._get_actual_stage_num() + + # compute parameters after updating in cuda rpc + parameters = engine.remote_parameters() + for stage_id in range(actual_stage_num): + for p in parameters[stage_id]: + cuda_rpc_result.append(p) + + # compute forward result and backward grad of parameters just in rank_0 + test_model = nn.Sequential( + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) + optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr) + input_sample = input_sample.requires_grad_() + out_val = test_model(input_sample).sum() + autograd.backward(out_val) + optimizer.step() + optimizer.zero_grad() + + for p in test_model.parameters(): + single_result.append(p) + + assert len(cuda_rpc_result) == len(single_result) + for r_c, r_s in zip(cuda_rpc_result, single_result): + assert_close(r_c, r_s, 0.001, 0.001) + + +if __name__ == "__main__": + args = parse_args() + rpc_run(args, run_master) diff --git a/tests/test_pipeline/test_cuda_rpc_performance.py b/tests/test_pipeline/test_cuda_rpc_performance.py new file mode 100644 index 0000000000000000000000000000000000000000..6a0509555862ea0b1be5ba10c961075af854720a --- /dev/null +++ b/tests/test_pipeline/test_cuda_rpc_performance.py @@ -0,0 +1,90 @@ +import os +from typing import Callable, List, Optional, Type, Union +import time + +import pytest +import torch +import torch.nn as nn +from titans.dataloader.cifar10 import build_cifar +from torchvision.models import resnet50 +from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 +from tqdm import tqdm + +from rpc_test_utils import rpc_run, parse_args +import colossalai +import colossalai.nn as col_nn +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.trainer import Trainer, hooks +from colossalai.utils import MultiTimer, get_dataloader +from colossalai.context import ParallelMode +from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel +from colossalai.pipeline.rpc import OneFOneBPipelineEngine, ChimeraPipelineEngine +from colossalai.pipeline.pipeline_process_group import ppg + + +def flatten(x): + return torch.flatten(x, 1) + + +def partition(pp_rank: int, chunk: int, stage_num: int): + pipelinable = PipelinableContext() + + # build model partitions + with pipelinable: + # input : [B, 3, 32, 32] + _ = resnet50() + + pipelinable.policy = "customized" + + exec_seq = [ + 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc' + ] + pipelinable.to_layer_list(exec_seq) + partition = pipelinable.partition(chunk, stage_num, pp_rank) + return partition + + +def run_master(args): + batch_size = args.batch_size + chunk = args.chunk + device = args.device + world_size = args.world_size + stage_num = world_size + num_microbatches = args.num_microbatches + + # build dataloader + root = os.environ.get('DATA', './data') + train_dataloader, test_dataloader = build_cifar(batch_size, root, padding=4, crop=32, resize=32) + criterion = nn.CrossEntropyLoss() + + pp_engine = OneFOneBPipelineEngine(partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + criterion=criterion, + checkpoint=False) + + pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3) + s = time.time() + + for bx, by in tqdm(train_dataloader): + pp_engine.forward_backward(bx, labels=by, forward_only=False) + + cost_time = time.time() - s + + print("total cost time :", cost_time) + print("cost time per batch:", cost_time / len(train_dataloader)) + + +@pytest.mark.skip("Test for performance, no need for CI") +def main(): + args = parse_args() + # this is due to limitation of partition function + args.world_size = 2 + args.chunk = 1 + rpc_run(args, run_master) + + +if __name__ == '__main__': + main() diff --git a/tests/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_pipeline/test_cuda_rpc_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..8d03e79813e89976153def8ecc7873af2c913701 --- /dev/null +++ b/tests/test_pipeline/test_cuda_rpc_pipeline.py @@ -0,0 +1,48 @@ +import torch +from torch import nn + +from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from rpc_test_utils import rpc_run, parse_args, RpcTestModel + +# global variable for model created +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + + +def run_master(args): + torch.manual_seed(100) + + epoch = args.epoch + device = args.device + stage_num = args.world_size + chunk = args.chunk + num_microbatches = args.num_microbatches + use_checkpoint = args.use_checkpoint + + sample_num = 1024 + batch_size = 1024 + + assert sample_num % batch_size == 0 + + input_sample = torch.randn((sample_num, feat_num), device=device) + + engine = OneFOneBPipelineEngine(partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint) + + for _ in range(epoch): + _ = engine.forward_backward(input_sample, forward_only=False) + + +if __name__ == "__main__": + args = parse_args() + rpc_run(args, run_master) diff --git a/tests/test_pipeline/test_cuda_rpc_value_correctness.py b/tests/test_pipeline/test_cuda_rpc_value_correctness.py new file mode 100644 index 0000000000000000000000000000000000000000..e6713478baecae9115ac3142d5c601d850bec070 --- /dev/null +++ b/tests/test_pipeline/test_cuda_rpc_value_correctness.py @@ -0,0 +1,73 @@ +import torch +from torch import nn +from torch import autograd + +from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.testing import assert_close +from rpc_test_utils import rpc_run, parse_args, RpcTestModel + +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + + +def run_master(args): + torch.manual_seed(100) + + device = args.device + stage_num = args.world_size + chunk = args.chunk + actual_stage_num = stage_num * chunk + use_checkpoint = args.use_checkpoint + num_microbatches = args.num_microbatches + + sample_num = 1024 + batch_size = 1024 + + assert sample_num % batch_size == 0 + + input_sample = torch.randn((sample_num, feat_num), device=device) + + engine = OneFOneBPipelineEngine(partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint) + + forward_result = engine.forward_backward(input_sample) + + cuda_rpc_result = [] + single_result = [] + actual_stage_num = engine._get_actual_stage_num() + + # compute forward result and backward grad of parameters in cuda rpc + cuda_rpc_result.append(sum(forward_result[0])) + grad = engine.remote_grad() + for stage_id in range(actual_stage_num): + for p in grad[stage_id]: + cuda_rpc_result.append(p) + + # compute forward result and backward grad of parameters just in rank_0 + test_model = nn.Sequential( + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) + input_sample = input_sample.requires_grad_() + out_val = test_model(input_sample).sum() + autograd.backward(out_val) + single_result.append(out_val) + for p in test_model.parameters(): + single_result.append(p.grad) + + assert len(cuda_rpc_result) == len(single_result) + for r_c, r_s in zip(cuda_rpc_result, single_result): + assert_close(r_c, r_s, 0.001, 0.001) + + +if __name__ == "__main__": + args = parse_args() + rpc_run(args, run_master) diff --git a/tests/test_pipeline/test_middleware_1f1b.py b/tests/test_pipeline/test_middleware_1f1b.py new file mode 100644 index 0000000000000000000000000000000000000000..5b3aad70327598341d7196ae97b0d2e791f57e39 --- /dev/null +++ b/tests/test_pipeline/test_middleware_1f1b.py @@ -0,0 +1,145 @@ +import os +from functools import partial + +import pytest +import torch +import torch.distributed.rpc as rpc +from rpc_test_utils import DAG_MLP, MLP +from torch._C._distributed_rpc import _is_current_rpc_agent_set + +from colossalai import launch +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.middleware.adaptor import get_fx_topology +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + +# global variable for model created +batch_size = 16 +dim = 10 +rpc_is_initialized = _is_current_rpc_agent_set + + +def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): + model.eval() + tracer = ColoTracer() + meta_args = {k: v.to('meta') for k, v in data_kwargs.items()} + graph = tracer.trace(root=model, meta_args=meta_args) + gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) + annotated_model = balanced_split_pass(gm, stage_num) + top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True) + topo = get_fx_topology(top_module) + for submodule in split_submodules: + if isinstance(submodule, torch.fx.GraphModule): + setattr(submodule, '_topo', topo) + return split_submodules[pp_rank + 1] + + +def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = create_partition_module(pp_rank, stage_num, model, data_kwargs) + return partition + + +def run_master(model_cls, world_size, forward_only): + torch.manual_seed(100) + + epoch = 3 + device = 'cuda' + stage_num = world_size + chunk = 1 + num_microbatches = 8 + use_checkpoint = 'store_true' + + if model_cls == MLP: + + def data_gen(): + x = torch.zeros((batch_size, dim)) + kwargs = dict(x=x) + return kwargs + + model = model_cls(dim, stage_num * 3) + if forward_only: + labels = None + else: + labels = 1 + elif model_cls == DAG_MLP: + + def data_gen(): + x = torch.zeros((batch_size, dim)) + y = torch.zeros((batch_size, dim)) + kwargs = dict(x=x, y=y) + return kwargs + + model = model_cls(dim, stage_num * 3) + if forward_only: + labels = None + else: + labels = 1 + else: + pass + + data_kwargs = data_gen() + + engine = OneFOneBPipelineEngine( + partition_fn=partial(partition, model, data_kwargs), + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint, + ) + if not forward_only: + engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3) + + for _ in range(epoch): + input_x = torch.randn((batch_size, dim), device=device) + input_y = torch.randn((batch_size, dim), device=device) + logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only) + + +def run_worker(rank, world_size, port, model_cls, forward_only, master_func): + master_addr = 'localhost' + master_port = 29020 + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = str(master_port) + + disable_existing_loggers() + + launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False) + ppg.set_global_info(rank=rank, + world_size=world_size, + dp_degree=1, + tp_degree=1, + num_worker_threads=128, + device='cuda') + + # in rpc mode, only rank 0 is needed to be coded + if rank == 0: + master_func(model_cls, world_size, forward_only) + # barrier here + if rpc_is_initialized(): + rpc.shutdown() + + +@pytest.mark.skip("skip due to CI torch version 1.11") +@parameterize('model_cls', [MLP, DAG_MLP]) +@parameterize('forward_only', [True, False]) +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pp_middleware_fwd(model_cls, forward_only): + world_size = 4 + master_func = run_master + spawn( + run_worker, + world_size, + model_cls=model_cls, + forward_only=forward_only, + master_func=master_func, + ) + + +if __name__ == "__main__": + test_pp_middleware_fwd() diff --git a/tests/test_pipeline/test_pipelinable.py b/tests/test_pipeline/test_pipelinable.py new file mode 100644 index 0000000000000000000000000000000000000000..627cb5ac6f51968ae29c2743d5ec5a730af70814 --- /dev/null +++ b/tests/test_pipeline/test_pipelinable.py @@ -0,0 +1,57 @@ +import torch + +from colossalai.pipeline.pipelinable import PipelinableContext +from colossalai.testing import rerun_if_address_is_in_use, rerun_on_exception, spawn + +NUM_CHUNKS = 1 +PIPELINE_SIZE = 2 + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int = 256): + super().__init__() + intermediate_dim = dim * 4 + self.dense_1 = torch.nn.Linear(dim, intermediate_dim) + self.activation = torch.nn.GELU() + self.dense_2 = torch.nn.Linear(intermediate_dim, dim) + self.dropout = torch.nn.Dropout(0.1) + + def forward(self, x): + x = self.dense_1(x) + x = self.activation(x) + x = self.dense_2(x) + x = self.dropout(x) + return x + + +def run_pipelinable(rank, world_size, port): + pipelinable = PipelinableContext() + with pipelinable: + model = MLP() + + assert pipelinable.policy == "balanced" + pipelinable.policy = "uniform" + assert pipelinable.policy == "uniform" + pipelinable.to_layer_list() + + assert pipelinable.layers_count == len(list(model.children())) + + pipeline_model_part_0 = pipelinable.partition(NUM_CHUNKS, PIPELINE_SIZE, 0) + assert isinstance(pipeline_model_part_0, torch.nn.Module) + pipeline_model_part_1 = pipelinable.partition(NUM_CHUNKS, PIPELINE_SIZE, 1) + assert isinstance(pipeline_model_part_1, torch.nn.Module) + + layers_count_in_part_0 = len(list(pipeline_model_part_0._module_list)) + layers_count_in_part_1 = len(list(pipeline_model_part_1._module_list)) + + assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count + + +@rerun_if_address_is_in_use() +def test_pipelinable(): + spawn(run_pipelinable, 1) + + +if __name__ == '__main__': + test_pipelinable() diff --git a/tests/test_pipeline/test_pipeline_process_group.py b/tests/test_pipeline/test_pipeline_process_group.py new file mode 100644 index 0000000000000000000000000000000000000000..2a00e3ac55b195b8b3cde6ec9c17b949dc7511af --- /dev/null +++ b/tests/test_pipeline/test_pipeline_process_group.py @@ -0,0 +1,42 @@ +import os + +import torch.distributed.rpc as rpc +from rpc_test_utils import pg_parse_args, rpc_is_initialized + +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.testing import spawn + + +def run_worker(rank, args): + os.environ['MASTER_ADDR'] = args.master_addr + os.environ['MASTER_PORT'] = args.master_port + + device = args.device + world_size = args.world_size + dp_degree = args.dp_degree + tp_degree = args.tp_degree + num_worker_threads = args.num_worker_threads + host = args.master_addr + port = args.master_port + backend = 'nccl' if device == 'cuda' else 'gloo' + + disable_existing_loggers() + launch(dict(), rank, world_size, host, int(port), backend, verbose=False) + + ppg.set_global_info(rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device) + + if rpc_is_initialized(): + rpc.shutdown() + + +if __name__ == "__main__": + args = pg_parse_args() + world_size = args.world_size + spawn(run_worker, world_size, args=args) diff --git a/tests/test_tensor/common_utils/__init__.py b/tests/test_tensor/common_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5387db70445ffec90f2ffaafbc7bf41368ce6dc3 --- /dev/null +++ b/tests/test_tensor/common_utils/__init__.py @@ -0,0 +1 @@ +from ._utils import * diff --git a/tests/test_tensor/common_utils/_utils.py b/tests/test_tensor/common_utils/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b405f8cd2108368b336f432a9ddc41b061c9f0c5 --- /dev/null +++ b/tests/test_tensor/common_utils/_utils.py @@ -0,0 +1,91 @@ +import os +import random + +import numpy as np +import torch +import torch.distributed as dist +from torch.testing import assert_close + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.tensor import ComputePattern, ComputeSpec, ShardSpec + + +def set_seed(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True + + +def replace_parameter_add_grad(layer, weight=None, bias=None): + if weight is not None: + delattr(layer, 'weight') + setattr(layer, 'weight', weight) + layer.weight.requires_grad = True + if bias is not None: + delattr(layer, 'bias') + setattr(layer, 'bias', bias) + layer.bias.requires_grad = True + + +def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0): + dist.broadcast(tensor, src=0) + tensor_chunk = torch.chunk(tensor, chunk_size, dim=-1)[local_rank] + return tensor_chunk.clone() + + +def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor, rtol: float = 1e-3, atol: float = 1e-1): + assert_close(t_a, t_b, rtol=rtol, atol=atol) + return True + + +def tensor_shard_equal(tensor: torch.Tensor, + shard: torch.Tensor, + rank: int, + world_size: int, + rtol: float = 1e-3, + atol: float = 1e-1): + assert tensor.ndim == shard.ndim + if tensor.shape == shard.shape: + return tensor_equal(tensor, shard, rtol, atol) + else: + dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape)) + if dims_not_eq.numel() == 1: + # 1D shard + dim = dims_not_eq.item() + if world_size is None: + world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + if rank is None: + rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + return tensor_equal(tensor.chunk(world_size, dim)[rank], shard, rtol, atol) + else: + raise NotImplementedError + + +def split_param_single_dim_tp1d(dim, param, pg): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + if param.process_group.tp_world_size() == 1: + param.set_process_group(pg) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param, pg): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param, pg): + split_param_single_dim_tp1d(-1, param, pg) + + +def debug_print(ranks, *args): + if dist.get_rank() in ranks: + print(*args) + dist.barrier() diff --git a/tests/test_tensor/core/test_dist_spec_mgr.py b/tests/test_tensor/core/test_dist_spec_mgr.py new file mode 100644 index 0000000000000000000000000000000000000000..89476a35b63a0575c1f2cb30c0c7a74e0f4c0596 --- /dev/null +++ b/tests/test_tensor/core/test_dist_spec_mgr.py @@ -0,0 +1,64 @@ +import math + +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def run(): + group = ProcessGroup(tp_degree=dist.get_world_size()) + rank = dist.get_rank() + size = dist.get_world_size() + depth = int(math.sqrt(size)) + assert depth == math.sqrt(size) + x = torch.rand(8, 8).cuda() + old_dist_spec = ReplicaSpec() + row_spec = ShardSpec([0], [size]) + col_spec = ShardSpec([-1], [size]) + mat_spec = ShardSpec([0, 1], [depth, depth]) + row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec, group) + assert torch.equal(x.chunk(size, 0)[rank], row_shard) + assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec, group)) + col_shard = DistSpecManager._all_to_all(row_shard, row_spec, col_spec, group) + assert torch.equal(x.chunk(size, -1)[rank], col_shard) + assert torch.equal(x, DistSpecManager._gather(col_shard, col_spec, group)) + mat_shard = DistSpecManager._shard_as(x, old_dist_spec, mat_spec, group) + assert torch.equal(x.chunk(depth, 0)[rank // depth].chunk(depth, 1)[rank % depth], mat_shard) + assert torch.equal(x, DistSpecManager._gather(mat_shard, mat_spec, group)) + + +def check_mem(): + pg = ProcessGroup(tp_degree=dist.get_world_size()) + size = dist.get_world_size() + assert torch.cuda.memory_allocated() == 0 + x = torch.rand(32, 32).cuda() + orig_mem = x.numel() * x.element_size() + assert torch.cuda.memory_allocated() == orig_mem + old_dist_spec = ReplicaSpec() + row_spec = ShardSpec([0], [size]) + x.data = DistSpecManager._shard_as(x, old_dist_spec, row_spec, pg) + assert x.size(0) == 32 // size and x.size(1) == 32 + assert torch.cuda.memory_allocated() == orig_mem // size + x.data = DistSpecManager._gather(x, row_spec, pg) + assert torch.cuda.memory_allocated() == orig_mem + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_mem() + run() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_dist_spec_mgr(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_dist_spec_mgr(4) diff --git a/tests/test_tensor/core/test_tensor.py b/tests/test_tensor/core/test_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..64d198b350a82fd3d5722cdeed81fc9d283e8867 --- /dev/null +++ b/tests/test_tensor/core/test_tensor.py @@ -0,0 +1,153 @@ +import pytest +import torch +from numpy import allclose + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec, distspec +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def _run_tensor_indexing(): + pg = ProcessGroup() + torch_t = torch.randn(2, 3) + colo_t = ColoTensor(torch_t, ColoTensorSpec(pg)) + assert allclose(torch_t[:, 1], colo_t[:, 1]) + + +def _run_wrapped_tensor_func(): + pg = ProcessGroup() + t_ref = torch.randn(4, 5) + t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg)) + + # non-func attr + assert t.is_cuda == t_ref.is_cuda + + # return 1 torch.Tensor + t_abs = t.abs() + assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs, t_ref.abs()) + + # return 1 non-torch.Tensor + assert t.dim() == t_ref.dim() + + # return >1 torch.Tensor + assert isinstance(t, ColoTensor) + t_split1, t_split2 = t.split(2) + assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor), f"{type(t_split1)} {type(t_split2)}" + + +def _run_operand(world_size): + pg = ProcessGroup() + t_ref = torch.randn(4, 5) + t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg)) + + t_ref_res = t_ref + t_ref + t_res = t + t + + assert isinstance(t_res, ColoTensor) + assert torch.allclose(t_ref_res, t_res) + + pg = ProcessGroup(tp_degree=world_size) + t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg)) + t.set_dist_spec(ShardSpec([0], [world_size])) + t_new = torch.zeros_like(t) + assert isinstance(t_new, ColoTensor) + assert t_new.is_sharded() + + +#### Test Distributed init a Colotensor + + +def _run_view(world_size): + t_ref = torch.randn(4, 5) + rank = gpc.get_global_rank() + pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size) + t = ColoTensor.from_torch_tensor( + t_ref, ColoTensorSpec(pg, dist_attr=ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()]))) + + assert t.size_global()[0] == 4 * world_size + assert t.size_global(1) == 5 + assert t.size_global() == torch.Size([4 * world_size, 5]) + + t = t.view(4 * 5 * world_size) + assert t.shape == torch.Size([4 * 5 * world_size]) + + +def _run_tensor_shard_init(world_size): + t_ref = torch.randn(4, 5) + pg = ProcessGroup(tp_degree=world_size) + shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()]) + tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr) + t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec) + t.set_dist_spec(ReplicaSpec()) + + assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})" + + +def _run_tensor_replicated_init(world_size): + t_ref = torch.randn(4 * world_size, 5) + pg = ProcessGroup() + spec = ColoTensorSpec(pg) + t = ColoTensor.from_torch_tensor(t_ref.clone(), spec) + + assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}" + + +def _run_process_group(world_size): + pg1 = ProcessGroup() + pg2 = ProcessGroup() + assert pg1 == pg2 + + +def _run_redistributed(world_size): + if world_size != 4: + return + pg1 = ProcessGroup(tp_degree=2, dp_degree=2) + pg2 = ProcessGroup(tp_degree=4, dp_degree=1) + + spec1 = ColoTensorSpec(pg1) + t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1) + t1 = t1.redistribute(ShardSpec([0], [pg1.tp_world_size()])) + assert t1.is_sharded() + t1 = t1.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2) + assert t1.is_sharded() + pg3 = ProcessGroup(tp_degree=1, dp_degree=4) + t1 = t1.redistribute(ReplicaSpec(), pg3) + assert t1.is_replicate() + + +def _run_set_tensor_spec(world_size): + if world_size != 4: + return + pg = ProcessGroup(tp_degree=2, dp_degree=2) + spec1 = ColoTensorSpec(pg) + t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1) + + dist_spec2 = ShardSpec([-1], [pg.tp_world_size()]) + assert t1.is_replicate() + t1.set_dist_spec(dist_spec2) + assert t1.is_shard_1dcol() + + +def run_dist_tests(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_tensor_shard_init(world_size) + _run_tensor_replicated_init(world_size) + _run_view(world_size) + _run_process_group(world_size) + _run_tensor_indexing() + _run_operand(world_size) + _run_wrapped_tensor_func() + _run_redistributed(world_size) + _run_set_tensor_spec(world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_dist_cases(world_size): + spawn(run_dist_tests, world_size) + + +if __name__ == '__main__': + test_dist_cases(4) diff --git a/tests/test_tensor/model/test_gpt2.py b/tests/test_tensor/model/test_gpt2.py new file mode 100644 index 0000000000000000000000000000000000000000..337bfa840d5da1866a3f822803ace857bdefc4b9 --- /dev/null +++ b/tests/test_tensor/model/test_gpt2.py @@ -0,0 +1,148 @@ +import pytest +import torch +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.nn.parallel.data_parallel import ColoDDP +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import ( + debug_print, + set_seed, + split_param_col_tp1d, + split_param_row_tp1d, + tensor_equal, + tensor_shard_equal, +) + + +def init_1d_row_spec(model, pg: ProcessGroup): + tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + for n, p in model.named_parameters(): + p.set_process_group(pg) + if 'weight' in n and 'ln' not in n: + p.set_tensor_spec(*tensor_spec) + + +def init_1d_col_spec(model, pg: ProcessGroup): + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + + for n, p in model.named_parameters(): + p.set_process_group(pg) + if 'ln' not in n and ('weight' in n or 'bias' in n): + p.set_tensor_spec(*spec) + + +def init_megatron_spec(model, pg: ProcessGroup): + for mn, module in model.named_modules(): + # debug_print([0], mn) + for pn, param in module.named_parameters(recurse=False): + # debug_print([0], '\t', pn, param.compute_spec, param.shape) + param.set_process_group(pg) + + if 'mlp.c_fc' in mn: + if 'weight' in pn or 'bias' in pn: + split_param_col_tp1d(param, pg) + param.compute_spec.set_output_replicate(False) + else: + raise RuntimeError + elif 'mlp.c_proj' in mn: + if 'weight' in pn: + split_param_row_tp1d(param, pg) + else: + assert 'bias' in pn + elif 'wte' in mn or 'wpe' in mn: + assert 'weight' in pn + split_param_col_tp1d(param, pg) + elif 'c_attn' in mn or 'c_proj' in mn: + split_param_col_tp1d(param, pg) + # debug_print([0], '\t', param.compute_spec, param.shape) + + +def check_param_equal(model, torch_model, pg: ProcessGroup): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}1" + assert pg.tp_world_size() is not None + assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size()) + + +def check_grad_equal(model, torch_model, pg: ProcessGroup): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + assert tensor_shard_equal(torch_p.grad, p.grad, pg.tp_local_rank(), pg.tp_world_size()) + + +def run_gpt(init_spec_func, use_ddp): + world_size = torch.distributed.get_world_size() + + # build a PG with TP and DP hybrid + pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1)) + + # set seed make processes of the same tp group use the same seed + # set_seed(pg.tp_local_rank()) + + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + # make sure torch_model and model has the same parameter values + with ColoInitContext(device=get_current_device()): + model = model_builder() + model = model.cuda() + torch_model = model_builder().cuda() + + if use_ddp: + torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) + model = ColoDDP(model, process_group=pg) + + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p) + + init_spec_func(model, pg) + + check_param_equal(model, torch_model, pg) + + # close the dropout in eval mode + model.eval() + torch_model.eval() + set_seed(pg.dp_local_rank()) + torch.distributed.barrier() + for i, (input_ids, label) in enumerate(train_dataloader): + colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg)) + logits = model(colo_input) + torch_logits = torch_model(input_ids) + assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}" + loss = criterion(logits, input_ids) + torch_loss = criterion(torch_logits, input_ids) + if use_ddp: + model.backward(loss) + else: + loss.backward() + torch_loss.backward() + check_grad_equal(model, torch_model, pg) + if i > 0: + break + set_seed(313) + + +def run_dist(rank, world_size, port, use_ddp): + if use_ddp and world_size == 1: + return + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # Comments below tests for speed concern + # run_gpt(init_1d_row_spec, use_ddp) + # run_gpt(init_1d_col_spec, use_ddp) + run_gpt(init_megatron_spec, use_ddp) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize('use_ddp', [False, True]) +@rerun_if_address_is_in_use() +def test_gpt(world_size, use_ddp): + spawn(run_dist, world_size, use_ddp=use_ddp) + + +if __name__ == '__main__': + test_gpt(4, use_ddp=False) diff --git a/tests/test_tensor/model/test_model.py b/tests/test_tensor/model/test_model.py new file mode 100644 index 0000000000000000000000000000000000000000..79d70e53c5cb133e31c1741799b2eb3b0dc424b8 --- /dev/null +++ b/tests/test_tensor/model/test_model.py @@ -0,0 +1,334 @@ +import pytest +import torch + +import colossalai +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.tensor import ColoTensor, ProcessGroup +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import ( + check_equal, + set_seed, + split_param_col_tp1d, + split_param_row_tp1d, + tensor_shard_equal, +) + + +def run_1d_hybrid_tp(model_name): + # A simple net with two stacked nn.Linear + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + set_seed(1) + with ColoInitContext(device=get_current_device()): + model = model_builder(checkpoint=True) + + if rank == 0: + model_torch = model_builder(checkpoint=True) + model_torch = model_torch.cuda() + + optimizer_torch = ColossalaiOptimizer(torch.optim.SGD(model_torch.parameters(), lr=0.1)) + + # Make two models have the same init params + for p1, p2 in zip(model.parameters(), model_torch.parameters()): + p2.data.copy_(p1.data) + else: + model_torch = None + optimizer_torch = None + + pg = ProcessGroup(tp_degree=world_size) + if 'bert' == model_name: + for name, p in model.named_parameters(): + if not isinstance(p, ColoTensor): + continue + + # num_class = type_vocab_size = 2 | (8, 2) + if 'classifier' in name and 'weight' in name: + split_param_col_tp1d(p, pg) + # num_class = vocab_size = 30524 | (30524, 8) + elif 'word_embeddings' in name and 'weight' in name: + split_param_row_tp1d(p, pg) + # num_class = seq_len = 512 | (512, 8) + elif 'position_embeddings' in name and 'weight' in name: + split_param_row_tp1d(p, pg) + # num_class = type_vocab_size = 2 | (2, 8) + elif 'token_type_embeddings' in name and 'weight' in name: + split_param_col_tp1d(p, pg) + + elif "simple_net" == model_name: + # A naive way to set spec for all weights in Linear + for name, p in model.named_parameters(): + if not isinstance(p, ColoTensor): + continue + if 'embed' in name and 'weight' in name: + split_param_col_tp1d(p, pg) + if 'proj1' in name and ('weight' in name or 'bias' in name): + split_param_row_tp1d(p, pg) + if 'proj2' in name and 'weight' in name: + split_param_col_tp1d(p, pg) + if 'classifier' in name and ('weight' in name or 'bias' in name): + split_param_row_tp1d(p, pg) + + model = model.cuda() + model.eval() + if rank == 0: + model_torch.eval() + + colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1)) + + for i, (data, label) in enumerate(train_dataloader): + + # Zero grad + colo_optimizer.zero_grad() + if rank == 0: + optimizer_torch.zero_grad() + torch.distributed.barrier() + + data = data.to(get_current_device()) + label = label.to(get_current_device()) + + torch.distributed.broadcast(data, 0, group=pg.tp_process_group()) + torch.distributed.broadcast(label, 0, group=pg.tp_process_group()) + + # Bcast rank0 data to all processes + if criterion: + output = model(data) + loss = criterion(output, label) + else: + output = model(data, label) + loss = output + + # Test output + if rank == 0: + if criterion: + output_torch = model_torch(data) + loss_torch = criterion(output_torch, label) + else: + output_torch = model_torch(data, label) + loss_torch = output_torch + assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed" + torch.distributed.barrier() + + loss.backward() + colo_optimizer.step() + + if rank == 0: + loss_torch.backward() + optimizer_torch.step() + + with torch.no_grad(): + # check param + for p, torch_p in zip(model.parameters(), model_torch.parameters()): + assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size()) + torch.distributed.barrier() + if i > 5: + break + + +# Test the overrided parameters() and named_parameters() member functions +def test_model_parameters(): + colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') + + # build a module with 2 Linear, 4 parameters in total. + class Net(torch.nn.Module): + + def __init__(self): + super().__init__() + self.fcs = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 2)) + self.extra_param = torch.nn.Parameter(torch.randn(2)) + + with ColoInitContext(device=get_current_device()): + model = Net() + + param_cnt = 0 + for name, p in model.named_parameters(): + param_cnt += 1 + assert param_cnt == 5 + + for name, colo_p in model.named_parameters(): + assert colo_p.is_model_data() + + param_cnt = 0 + for name, p in model.named_parameters(recurse=False): + param_cnt += 1 + assert param_cnt == 1 + + param_cnt = 0 + for p in model.fcs[0].parameters(recurse=False): + param_cnt += 1 + assert param_cnt == 2 + + +def test_colo_optimizer(): + get_components_func = non_distributed_component_funcs.get_callable('simple_net') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + set_seed(1) + with ColoInitContext(device=get_current_device()): + model = model_builder(checkpoint=True) + + colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1)) + for i, (data, label) in enumerate(train_dataloader): + colo_optimizer.zero_grad() + data = data.to(get_current_device()) + label = label.to(get_current_device()) + + # Bcast rank0 data to all processes + if criterion: + output = model(data) + loss = criterion(output, label) + else: + output = model(data, label) + loss = output + + loss.backward() + colo_optimizer.step() + + if i > 5: + break + + +def run_1d_row_tp(model_name: str): + # A simple net with two stacked nn.Linear + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + rank = torch.distributed.get_rank() + + set_seed(1) + with ColoInitContext(device=get_current_device()): + model = model_builder(checkpoint=True) + + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + + set_seed(1) + if rank == 0: + model_torch = model_builder(checkpoint=True) + model_torch = model_torch.cuda() + + # A naive way to set spec for all weights in Linear + for mo_name, module in model.named_modules(): + # print(mo_name) + for pa_name, param in module.named_parameters(recurse=False): + # print('\t', pa_name, param.shape) + if not isinstance(param, ColoTensor): + continue + if 'weight' in pa_name: + if 'embed' in mo_name and 'token' not in mo_name and 'LayerNorm' not in mo_name: + split_param_row_tp1d(param, pg) + elif 'LayerNorm' not in mo_name and 'ln' not in mo_name: + split_param_col_tp1d(param, pg) + + model = model.cuda() + + for i, (data, label) in enumerate(train_dataloader): + data = data.to(get_current_device()) + label = label.to(get_current_device()) + + torch.distributed.broadcast(data, 0, group=pg.tp_process_group()) + torch.distributed.broadcast(label, 0, group=pg.tp_process_group()) + + # Bcast rank0 data to all processes + if criterion: + output = model(data) + loss = criterion(output, label) + else: + output = model(data, label) + loss = output + + # For reference + if rank == 0: + if criterion: + output_torch = model_torch(data) + loss_torch = criterion(output_torch, label) + else: + output_torch = model_torch(data, label) + loss_torch = output_torch + assert torch.allclose(loss, loss_torch, rtol=1e-2) + torch.distributed.barrier() + + loss.backward() + + if rank == 0: + loss_torch.backward() + torch.distributed.barrier() + + if i > 5: + break + + +def _run_pretrain_load(): + from transformers import BertForMaskedLM + set_seed(1) + model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased') + with ColoInitContext(device=get_current_device()): + model = BertForMaskedLM.from_pretrained('bert-base-uncased') + + model_pretrained = model_pretrained.cuda() + model = model.cuda() + + dict_pretrained = {} + dict_col = {} + c_ref = 0 + for name, param in model_pretrained.named_parameters(): + dict_pretrained[name] = param + c_ref += 1 + c1 = 0 + c2 = 0 + for name, param in model.named_parameters(): + if isinstance(param, ColoParameter): + c1 += 1 + else: + c2 += 1 + dict_col[name] = param + assert c_ref == c1 + assert c2 == 0 + if model_pretrained.cls.predictions.decoder.bias is model_pretrained.cls.predictions.bias: + assert model.cls.predictions.decoder.bias is model.cls.predictions.bias + + for name, param in dict_pretrained.items(): + check_equal(param, dict_col[name]) + + +def run_model_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # Comment below test for speed consideration + # for name in ['bert', 'simple_net']: + # run_1d_row_tp(name) + for name in ['bert', 'simple_net']: + run_1d_hybrid_tp(name) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_model(world_size): + spawn(run_model_dist, world_size) + + +def run_pretrain_load_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_pretrain_load() + + +# The test case has to download huggingface pretrained models from the internet +# So we manually trigger the test. +@pytest.mark.skip +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_pretrain_load(world_size): + spawn(run_pretrain_load_dist, world_size) + + +if __name__ == '__main__': + # test_model_parameters() + # test_colo_optgimizer() + test_model(4) + # test_pretrain_load(4) diff --git a/tests/test_tensor/model/test_module_spec.py b/tests/test_tensor/model/test_module_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..b50851e5eaf2f04cd5c7480a0fa2dc4ff02c4876 --- /dev/null +++ b/tests/test_tensor/model/test_module_spec.py @@ -0,0 +1,227 @@ +from copy import deepcopy + +import pytest +import torch + +import colossalai +from colossalai.nn.parallel.layers import check_colo_module, init_colo_module +from colossalai.tensor import ( + ColoTensor, + ColoTensorSpec, + ComputePattern, + ComputeSpec, + ProcessGroup, + ReplicaSpec, + ShardSpec, + distspec, +) +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal + + +def run_model_with_spec(mode, model_name): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + rank = pg.rank() + + set_seed(1) + with ColoInitContext(device=get_current_device()): + model = model_builder(checkpoint=False) + + if rank == 0: + model_seq = model_builder(checkpoint=False) + model_seq = model_seq.cuda() + + # Make two models have the same init params + for p1, p2 in zip(model.parameters(), model_seq.parameters()): + p2.data.copy_(p1.data) + + compute_spec = ComputeSpec(ComputePattern.TP1D) + # Not all layers in Bert can be mod by 4. + # e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2. + if 'bert' == model_name: + if 'col' == mode: + init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode=mode) + init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode) + init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode='row') + elif 'row' == mode: + init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode='col') + init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode) + init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode=mode) + elif 'simple_net' == model_name: + init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode) + + model = model.cuda() + for i, (data, label) in enumerate(train_dataloader): + data = data.to(get_current_device()) + label = label.to(get_current_device()) + + torch.distributed.broadcast(data, 0, group=pg.tp_process_group()) + torch.distributed.broadcast(label, 0, group=pg.tp_process_group()) + + if criterion: + output = model(data) + loss = criterion(output, label) + else: + output = model(data, label) + loss = output + + # For reference + if rank == 0: + if criterion: + output_seq = model_seq(data) + loss_seq = criterion(output_seq, label) + else: + output_seq = model_seq(data, label) + loss_seq = output_seq + + if rank == 0: + with torch.no_grad(): + assert torch.allclose(loss, loss_seq, rtol=1e-2) + + loss.backward() + + if rank == 0: + loss_seq.backward() + + with torch.no_grad(): + # check param + for p1, p2 in zip(model.parameters(), model_seq.parameters()): + if p1.size() == p2.size(): + assert torch.allclose(p1, p2) + else: + if p1.size(-1) < p2.size(-1): # col + world_size = p2.size(-1) // p1.size(-1) + split_p2 = torch.chunk(p2, world_size, dim=-1)[0] + + elif p1.size(0) < p2.size(0): # row + world_size = p2.size(0) // p1.size(0) + split_p2 = torch.chunk(p2, world_size, dim=0)[0] + + assert torch.allclose(p1, split_p2) + + if i > 3: + break + + +def run_linear_with_spec(mode): + with ColoInitContext(device=get_current_device()): + model = torch.nn.Linear(4, 8) + + model_handy = deepcopy(model) + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + compute_spec = ComputeSpec(ComputePattern.TP1D) + init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode) + + x = torch.rand(2, 4).cuda() + colo_x = ColoTensor.from_torch_tensor(x, ColoTensorSpec(pg)) + + out = model(x) + colo_out = model_handy(colo_x) + assert tensor_equal(out, colo_out) + + grad = torch.rand_like(out) + out.backward(grad) + colo_out.backward(grad) + + assert tensor_shard_equal(model_handy.weight.grad, model.weight.grad, pg.tp_local_rank(), pg.tp_world_size()) + assert tensor_shard_equal(model_handy.bias.grad, model.bias.grad, pg.tp_local_rank(), pg.tp_world_size()) + + +def run_check_shared_param(): + from transformers import BertConfig, BertForMaskedLM + hidden_dim = 8 + num_head = 4 + sequence_length = 12 + num_layer = 2 + vocab_size = 24 + + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + rank = pg.rank() + + config = BertConfig(vocab_size=vocab_size, + hidden_size=hidden_dim, + intermediate_size=hidden_dim * 4, + num_attention_heads=num_head, + max_position_embeddings=sequence_length, + num_hidden_layers=num_layer, + hidden_dropout_prob=0., + attention_probs_dropout_prob=0.) + with ColoInitContext(device=get_current_device()): + model = BertForMaskedLM(config) + + model = model.cuda() + compute_spec = ComputeSpec(ComputePattern.TP1D) + # model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec + assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2 + # They are all Linear, so both row is allowed. This should pass check. + init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row') + # This should be detected by check because you can not set weight as row while set bias as col. + col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + + # TODO(jiaruifang) optimize this line + if not model.cls.predictions.bias.has_initialized: + model.cls.predictions.bias.pg = pg + model.cls.predictions.bias.dist_spec = ReplicaSpec() + model.cls.predictions.bias.has_initialized = True + model.cls.predictions.bias.set_tensor_spec(*col_spec) + try: + check_colo_module(model.cls.predictions.decoder, pg=pg, recursive=False) + except Exception as e: + assert 'incorrectly sharded' in str(e) + + +def run_dist(rank, world_size, port): + config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_linear_with_spec('col') + run_linear_with_spec('row') + + +def run_dist_model(rank, world_size, port): + config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + for model_name in ['simple_net', 'bert']: + run_model_with_spec('col', model_name) + run_model_with_spec('row', model_name) + + +def run_dist_check(rank, world_size, port): + config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_check_shared_param() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.skip("for higher testing speed") +@rerun_if_address_is_in_use() +def test_module_linear_1d(world_size): + spawn(run_dist, world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.skip("for higher testing speed") +@rerun_if_address_is_in_use() +def test_module_model(world_size): + spawn(run_dist_model, world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@pytest.mark.skip("for higher testing speed") +@rerun_if_address_is_in_use() +def test_module_check(world_size): + spawn(run_dist_check, world_size) + + +if __name__ == '__main__': + test_module_linear_1d(4) diff --git a/tests/test_tensor/test_colo_checkpoint_tools.py b/tests/test_tensor/test_colo_checkpoint_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..a53a3f37a664134e9b0f379ff3ad7f8b5abd5b16 --- /dev/null +++ b/tests/test_tensor/test_colo_checkpoint_tools.py @@ -0,0 +1,41 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor +from tests.test_tensor.common_utils import tensor_shard_equal + + +def run_dist(rank, world_size, port, dp_degree, tp_degree): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree) + x = torch.randn(4, 4) + param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg)) + spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D) + param.set_tensor_spec(*spec) + + gather_tensor(param) + if dist.get_rank() == 0: + assert torch.all(x == param) + else: + assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) + dist.barrier() + + scatter_tensor(param, spec[0]) + assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) + assert param.requires_grad is True + dist.barrier() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4]) +@rerun_if_address_is_in_use() +def test_checkpoint(world_size): + spawn(run_dist, world_size, dp_degree=2, tp_degree=world_size // 2) + + +if __name__ == '__main__': + test_checkpoint(world_size=4) diff --git a/tests/test_tensor/test_comm_spec_apply.py b/tests/test_tensor/test_comm_spec_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..2c68633aabc819706e35f69dc4ebfbb521dc1972 --- /dev/null +++ b/tests/test_tensor/test_comm_spec_apply.py @@ -0,0 +1,219 @@ +import pytest +import torch + +from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_all_gather(device_mesh, rank): + # tensor to comm + if rank in (0, 2): + sharded_tensor_to_comm = torch.ones(2, 2).cuda() + else: + sharded_tensor_to_comm = torch.zeros(2, 2).cuda() + + # tensor to check + tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda() + + # test all gather + dim_partition_dict = {1: [1]} + + # DistSpec: + # shard_sequence: R,S1 + # device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=1) + sharded_tensor_to_comm = sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm) + + assert sharded_tensor_to_comm.equal(tensor_to_check) + + +def check_shard(device_mesh, rank): + # tensor to comm + sharded_tensor_to_comm_0 = torch.zeros(2, 2).cuda() + sharded_tensor_to_comm_1 = torch.ones(2, 2).cuda() + # tensor([[0., 0., 1., 1.], + # [0., 0., 1., 1.]]) + tensor_to_shard = torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1) + + # test shard + dim_partition_dict = {} + + # DistSpec: + # shard_sequence: R,R + # device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(device_mesh, tensor_to_shard.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, sharding_spec, shard_dim=1, logical_process_axis=1) + tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard) + + if rank in (0, 2): + assert tensor_to_shard.equal(sharded_tensor_to_comm_0) + if rank in (1, 3): + assert tensor_to_shard.equal(sharded_tensor_to_comm_1) + + +def check_all_to_all(device_mesh, rank): + # tensor to comm + if rank in (0, 1): + sharded_tensor_0 = torch.zeros(2, 1) + sharded_tensor_1 = torch.ones(2, 1) + # tensor([[0., 1.], + # [0., 1.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + if rank in (2, 3): + sharded_tensor_0 = torch.ones(2, 1) * 2 + sharded_tensor_1 = torch.ones(2, 1) * 3 + # tensor([[2., 3.], + # [2., 3.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + + if rank in (0, 1): + # tensor([[0.], + # [0.], + # [2.], + # [2.]]) + tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda() + if rank in (2, 3): + # tensor([[1.], + # [1.], + # [3.], + # [3.]]) + tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda() + + # test shard + dim_partition_dict = {0: [0]} + + # DistSpec: + # shard_sequence: S0,R + # device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(device_mesh, torch.Size((4, 2)), dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, + sharding_spec, + gather_dim=0, + shard_dim=1, + logical_process_axis=0) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_all_reduce_fwd(device_mesh, rank): + # tensor to comm + tensor_to_comm = torch.ones(2, 2).cuda() * rank + + # reduce through logical process axis 0 + # tensor to check + if rank in (0, 2): + # tensor([[2., 2.], + # [2., 2.]]) + tensor_to_check = torch.tensor([[2, 2], [2, 2]], dtype=tensor_to_comm.dtype).cuda() + if rank in (1, 3): + # tensor([[4., 4.], + # [4., 4.]]) + tensor_to_check = torch.tensor([[4, 4], [4, 4]], dtype=tensor_to_comm.dtype).cuda() + + dim_partition_dict = {} + # DistSpec: + # shard_sequence: R,R + # device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_all_reduce_bwd(device_mesh, rank): + # tensor to comm + tensor_to_comm = torch.ones(2, 2).cuda() * rank + + tensor_to_check = torch.ones(2, 2).cuda() * rank + + dim_partition_dict = {} + # DistSpec: + # shard_sequence: R,R + # device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, sharding_spec, logical_process_axis=0) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_all_reduce_in_flatten_device_mesh(device_mesh, rank): + # tensor to comm + tensor_to_comm = torch.ones(2, 2).cuda() * rank + + # reduce through logical process axis 0 at flatten device mesh + # tensor to check + # tensor([[6., 6.], + # [6., 6.]]) + tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda() + + dim_partition_dict = {} + # DistSpec: + # shard_sequence: R,R + # device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1]) + comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1]) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_comm(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + physical_mesh_id = torch.arange(0, 4) + assert rank == gpc.get_global_rank() + + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + # test all gather + check_all_gather(device_mesh, rank) + + # test shard + check_shard(device_mesh, rank) + + # test all to all + check_all_to_all(device_mesh, rank) + + # test all reduce + check_all_reduce_fwd(device_mesh, rank) + check_all_reduce_bwd(device_mesh, rank) + + # test all reduce in 1D flatten device mesh + check_all_reduce_in_flatten_device_mesh(device_mesh, rank) + gpc.destroy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_comm_spec(): + world_size = 4 + spawn(check_comm, world_size) + + +if __name__ == '__main__': + test_comm_spec() diff --git a/tests/test_tensor/test_context.py b/tests/test_tensor/test_context.py new file mode 100644 index 0000000000000000000000000000000000000000..45def034ba8e6f592dfe294d16561558d2139c3b --- /dev/null +++ b/tests/test_tensor/test_context.py @@ -0,0 +1,64 @@ +import pytest +import torch + +import colossalai +from colossalai.tensor import ( + ColoParameter, + ColoTensorSpec, + ComputePattern, + ComputeSpec, + ProcessGroup, + ReplicaSpec, + ShardSpec, +) +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed + + +def run_colo_init_context(rank: int, world_size: int, port: int): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # make sure seed of each process is the same, so the params are consistent among processes and the params are exactly replicated. + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + # keep parameters replicated during init + with ColoInitContext(device=get_current_device()): + model1 = model_builder() + + # shard the parameters during init + set_seed(42) + shard_spec = ReplicaSpec() + + # If using ShardSpec, the assertations will failed. + # But it is not a bug, the initialized values are not consist with the original one. + # shard_spec = ShardSpec(dims=[0], num_partitions=[world_size]) + default_pg = ProcessGroup(tp_degree=world_size) + with ColoInitContext(device=get_current_device(), default_pg=default_pg, default_dist_spec=shard_spec): + model2 = model_builder() + + # reshard both models + new_shard = ShardSpec(dims=[-1], num_partitions=[world_size]) + for p1, p2 in zip(model1.parameters(), model2.parameters()): + p1: ColoParameter = p1 + p1.set_process_group(ProcessGroup(tp_degree=world_size)) + p1.set_dist_spec(new_shard) + p2.set_dist_spec(new_shard) + + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert (torch.allclose(p1, p2)) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_colo_init_context(world_size): + spawn(run_colo_init_context, world_size) + + +if __name__ == '__main__': + test_colo_init_context(2) diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..d1f5b9299397f0b63e6c66e072878bdcf0059afd --- /dev/null +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -0,0 +1,185 @@ +import pytest +import torch +import torch.distributed as dist +from torch.distributed import ReduceOp + +from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec +from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_all_gather(process_groups_dict, rank): + # tensor to comm + if rank in (0, 2): + sharded_tensor_to_comm = torch.ones(2, 2).cuda() + else: + sharded_tensor_to_comm = torch.zeros(2, 2).cuda() + + # tensor to check + tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda() + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + process_groups_dict, + gather_dim=1, + logical_process_axis=1) + sharded_tensor_to_comm = sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm) + + assert sharded_tensor_to_comm.equal(tensor_to_check) + + +def check_shard(process_groups_dict, rank): + # tensor to comm + sharded_tensor_to_comm_0 = torch.zeros(2, 2).cuda() + sharded_tensor_to_comm_1 = torch.ones(2, 2).cuda() + # tensor([[0., 0., 1., 1.], + # [0., 0., 1., 1.]]) + tensor_to_shard = torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1) + + # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, + process_groups_dict, + shard_dim=1, + logical_process_axis=1) + tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard) + + if rank in (0, 2): + assert tensor_to_shard.equal(sharded_tensor_to_comm_0) + if rank in (1, 3): + assert tensor_to_shard.equal(sharded_tensor_to_comm_1) + + +def check_all_to_all(process_groups_dict, rank): + # tensor to comm + if rank in (0, 1): + sharded_tensor_0 = torch.zeros(2, 1) + sharded_tensor_1 = torch.ones(2, 1) + # tensor([[0., 1.], + # [0., 1.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + if rank in (2, 3): + sharded_tensor_0 = torch.ones(2, 1) * 2 + sharded_tensor_1 = torch.ones(2, 1) * 3 + # tensor([[2., 3.], + # [2., 3.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + + if rank in (0, 1): + # tensor([[0.], + # [0.], + # [2.], + # [2.]]) + tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda() + if rank in (2, 3): + # tensor([[1.], + # [1.], + # [3.], + # [3.]]) + tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda() + + # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, + process_groups_dict, + gather_dim=0, + shard_dim=1, + logical_process_axis=0) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_all_reduce_fwd(process_groups_dict, rank): + # tensor to comm + tensor_to_comm = torch.ones(2, 2).cuda() * rank + + # reduce through logical process axis 0 + # tensor to check + if rank in (0, 2): + # tensor([[2., 2.], + # [2., 2.]]) + tensor_to_check = torch.tensor([[2, 2], [2, 2]], dtype=tensor_to_comm.dtype).cuda() + if rank in (1, 3): + # tensor([[4., 4.], + # [4., 4.]]) + tensor_to_check = torch.tensor([[4, 4], [4, 4]], dtype=tensor_to_comm.dtype).cuda() + + comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_all_reduce_bwd(process_groups_dict, rank): + # tensor to comm + tensor_to_comm = torch.ones(2, 2).cuda() * rank + + tensor_to_check = torch.ones(2, 2).cuda() * rank + + comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, process_groups_dict, logical_process_axis=0) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank): + # tensor to comm + tensor_to_comm = torch.ones(2, 2).cuda() * rank + + # reduce through logical process axis 0 at flatten device mesh + # tensor to check + # tensor([[6., 6.], + # [6., 6.]]) + tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda() + + # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1]) + comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_comm(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + physical_mesh_id = torch.arange(0, 4) + assert rank == gpc.get_global_rank() + + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + process_groups_dict = device_mesh.process_groups_dict + + # test all gather + check_all_gather(process_groups_dict, rank) + + # test shard + check_shard(process_groups_dict, rank) + + # test all to all + check_all_to_all(process_groups_dict, rank) + + # test all reduce + check_all_reduce_fwd(process_groups_dict, rank) + check_all_reduce_bwd(process_groups_dict, rank) + + flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict + # test all reduce in 1D flatten device mesh + check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank) + gpc.destroy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_comm_spec(): + world_size = 4 + spawn(check_comm, world_size) + + +if __name__ == '__main__': + test_comm_spec() diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py new file mode 100644 index 0000000000000000000000000000000000000000..3ca369acbf87c6cc9785c9304fa3ed06bdae9661 --- /dev/null +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -0,0 +1,99 @@ +import torch + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor +from colossalai.tensor.d_tensor.layout import Layout +from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +class TestModel(torch.nn.Module): + + def __init__(self, in_features, out_features): + super().__init__() + self.linear_1 = torch.nn.Linear(in_features, out_features) + self.linear_2 = torch.nn.Linear(out_features, in_features) + + def forward(self, x): + x = self.linear_1(x) + x = self.linear_2(x) + return x + + +def check_dtensor(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + test_model = TestModel(8, 8).to('cuda') + original_tensor = torch.rand(4, 8).to('cuda') + compare_output = test_model(original_tensor) + + device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) + target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]}) + layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=target_sharding_spec, + entire_shape=original_tensor.shape) + d_tensor = DTensor(original_tensor, layout) + + assert d_tensor.entire_shape == original_tensor.shape + assert d_tensor.data_type == original_tensor.dtype + + if rank in (0, 1): + assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2)) + elif rank in (2, 3): + assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2)) + else: + raise ValueError(f'rank {rank} is not in the device mesh') + assert d_tensor.to_global().equal(original_tensor) + output = test_model(d_tensor) + + if rank in (0, 1): + assert output.equal(compare_output.narrow(0, 0, 2)) + elif rank in (2, 3): + assert output.equal(compare_output.narrow(0, 2, 2)) + else: + raise ValueError(f'rank {rank} is not in the device mesh') + + new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]}) + new_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=new_sharding_spec, + entire_shape=original_tensor.shape) + + d_tensor.layout_convert(new_layout) + + if rank == 0: + assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1)) + elif rank == 1: + assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1)) + elif rank == 2: + assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1)) + elif rank == 3: + assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1)) + else: + raise ValueError(f'rank {rank} is not in the device mesh') + + dtensor_from_local = distribute_tensor(original_tensor, new_layout) + + if rank == 0: + assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1)) + elif rank == 1: + assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1)) + elif rank == 2: + assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1)) + elif rank == 3: + assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1)) + else: + raise ValueError(f'rank {rank} is not in the device mesh') + + +@rerun_if_address_is_in_use() +def test_dtensor(): + world_size = 4 + spawn(check_dtensor, world_size) + + +if __name__ == '__main__': + test_dtensor() diff --git a/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py b/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..7fd1c3d90fc4c1c4810b8b177c92ec1d514e0d29 --- /dev/null +++ b/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py @@ -0,0 +1,34 @@ +import operator +from functools import reduce + +from colossalai.tensor.d_tensor.sharding_spec import ALLGATHER_COST, SHARD_COST, STEP_PENALTY, ShardingSpec + + +def test_dtensor_sharding_spec(): + dims = 4 + dim_partition_dict_0 = {0: [0, 1]} + # DistSpec: + # shard_sequence: S01,R,R,R + sharding_spec_0 = ShardingSpec(dims, dim_partition_dict=dim_partition_dict_0) + assert str(sharding_spec_0.sharding_sequence) == "[S01, R, R, R]" + + dim_partition_dict_1 = {1: [0, 1]} + # DistSpec: + # shard_sequence: R,S01,R,R + sharding_spec_1 = ShardingSpec(dims, dim_partition_dict=dim_partition_dict_1) + assert str(sharding_spec_1.sharding_sequence) == "[R, S01, R, R]" + + dim_spec_list_0 = [dim_spec for dim_spec in sharding_spec_0.sharding_sequence] + dim_spec_list_1 = [dim_spec for dim_spec in sharding_spec_1.sharding_sequence] + + assert dim_spec_list_0[0].dim_diff(dim_spec_list_1[0]) == ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + assert dim_spec_list_0[1].dim_diff(dim_spec_list_1[1]) == SHARD_COST + STEP_PENALTY + SHARD_COST + assert dim_spec_list_0[2].dim_diff(dim_spec_list_1[2]) == 0 + assert dim_spec_list_0[3].dim_diff(dim_spec_list_1[3]) == 0 + + assert sharding_spec_0.spec_diff(sharding_spec_1) == \ + reduce(operator.add, [dim_spec_list_0[i].dim_diff(dim_spec_list_1[i]) for i in range(dims)], 0) + + +if __name__ == '__main__': + test_dtensor_sharding_spec() diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..5f56decb5e5dd03a132c7b1ff2df9a64d63d7da9 --- /dev/null +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -0,0 +1,198 @@ +import math + +import pytest +import torch + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern +from colossalai.tensor.d_tensor.layout import Layout +from colossalai.tensor.d_tensor.layout_converter import LayoutConverter +from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn + +entire_shape = torch.Size((64, 32, 16)) +layout_converter = LayoutConverter() +physical_mesh_id = torch.arange(0, 4).reshape(2, 2) +mesh_shape = (2, 2) + + +def check_one_step_transform(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # [[0, 1], + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + dim_partition_dict = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) + layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec, + entire_shape=entire_shape) + + rst_dict = layout_converter.all_gather_transform_layouts(layout) + + assert '[R, S1, R]' in [ + str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys() + ] + assert '[S0, R, R]' in [ + str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys() + ] + + dim_partition_dict_all2all = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all) + layout_all2all = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_all2all, + entire_shape=entire_shape) + + rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all) + + assert '[S01, R, R]' in [ + str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() + ] + assert '[R, S1, S0]' in [ + str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() + ] + assert '[S0, R, S1]' in [ + str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() + ] + + dim_partition_shard = {0: [0]} + # DistSpec: + # shard_sequence: S0,R,R + # device_mesh_shape: (4, 4) + sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard) + shard_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_shard, + entire_shape=entire_shape) + + rst_dict_shard = layout_converter.shard_transform_layout(shard_layout) + + assert '[S01, R, R]' in [ + str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() + ] + assert '[S0, S1, R]' in [ + str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() + ] + assert '[S0, R, S1]' in [ + str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() + ] + + +def check_layout_converting(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + dim_partition_source = {1: [0, 1]} + dim_partition_target = {0: [0, 1]} + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # DistSpec: + # shard_sequence: R,S01,R + # device_mesh_shape: (4, 4) + sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) + source_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_source, + entire_shape=entire_shape) + + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) + target_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_target, + entire_shape=entire_shape) + + transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) + + # check transform path + transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) + assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]' + + # check comm action sequence + # all-gather(S01) -> S0 + assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD + assert comm_action_sequence[0].gather_dim == 1 + assert comm_action_sequence[0].logical_process_axis == 1 + + # all-to-all(R, S0) -> [S0, R] + assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD + assert comm_action_sequence[1].gather_dim == 1 + assert comm_action_sequence[1].shard_dim == 0 + assert comm_action_sequence[1].logical_process_axis == 0 + + # shard(S0) -> [S01] + assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD + assert comm_action_sequence[2].shard_dim == 0 + assert comm_action_sequence[2].logical_process_axis == 1 + + # checkout chached_spec_pairs_transform_path + assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path + assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence + + comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout) + + assert comm_cost['forward'] == comm_cost['backward'] + assert math.floor(comm_cost['total']) == math.floor(comm_cost['forward'] + comm_cost['backward']) + + +def check_layout_converting_apply(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + dim_partition_source = {1: [0, 1]} + dim_partition_target = {0: [0, 1]} + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # DistSpec: + # shard_sequence: R,S01,R + # device_mesh_shape: (4, 4) + sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) + source_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_source, + entire_shape=entire_shape) + + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) + target_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_target, + entire_shape=entire_shape) + + original_tensor = torch.rand(entire_shape).cuda() + + # tensor_to_apply: [R, S01, R] + tensor_to_apply = original_tensor.narrow(1, rank * 8, 8) + + # tensor_to_check: [S01, R, R] + tensor_to_check = original_tensor.narrow(0, rank * 16, 16) + + converted_tensor = layout_converter.apply(tensor_to_apply, source_layout, target_layout) + assert converted_tensor.equal(tensor_to_check) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_layout_converter(): + world_size = 4 + spawn(check_one_step_transform, world_size) + spawn(check_layout_converting, world_size) + spawn(check_layout_converting_apply, world_size) + + +if __name__ == '__main__': + test_layout_converter() diff --git a/tests/test_tensor/test_mix_gather.py b/tests/test_tensor/test_mix_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..9122808eb5a3ef841ab7af7344822abf90e7447f --- /dev/null +++ b/tests/test_tensor/test_mix_gather.py @@ -0,0 +1,330 @@ +import pytest +import torch + +from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.tensor.utils import mix_gather_simulator +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_mix_gather_S0S1(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + (f, b) = (0, 1) + f_target_pair = (f, [0]) + b_target_pair = (b, [1]) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_slice = [4, 2] # (4, 2) + rank_slice = 4 + f_start = (rank // rank_slice) * tensor_slice[0] + b_start = (rank % rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + dim_partition_dict = {0: [0], 1: [1]} + + # DistSpec: + # shard_sequence: S0,S1 + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_S0S1(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + dim_partition_dict = {0: [0], 1: [1]} + + tensor_slice = [tensor_width // 2, tensor_width // 4] # (4, 2) + rank_slice = 4 + f_start = (rank // rank_slice) * tensor_slice[0] + b_start = (rank % rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + # DistSpec: + # shard_sequence: S0,S1 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {1: [1]} + # DistSpec: + # shard_sequence: R,S1 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_mix_gather_S1S0(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + (f, b) = (0, 1) + f_target_pair = (f, [1]) + b_target_pair = (b, [0]) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_slice = [2, 4] + rank_slice = 4 + f_start = (rank % rank_slice) * tensor_slice[0] + b_start = (rank // rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + dim_partition_dict = {0: [1], 1: [0]} + + # DistSpec: + # shard_sequence: S1,S0 + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_S1S0(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + tensor_slice = [tensor_width // 4, tensor_width // 2] # (4, 2) + rank_slice = 4 + f_start = (rank % rank_slice) * tensor_slice[0] + b_start = (rank // rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + dim_partition_dict = {0: [1], 1: [0]} + + # DistSpec: + # shard_sequence: S1,S0 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {1: [0]} + # DistSpec: + # shard_sequence: R,S0 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_mix_gather_S01R(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + (f, b) = (0, 1) + f_target_pair = (f, [0, 1]) + b_target_pair = (b, []) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_to_comm = tensor_to_check[rank:rank + 1, :].contiguous().cuda() + + dim_partition_dict = {0: [0, 1]} + # DistSpec: + # shard_sequence: S01,R + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_S01R(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + rank_stride = tensor_width // 8 + tensor_to_comm = tensor_to_check[rank:rank + rank_stride, :].contiguous().cuda() + + dim_partition_dict = {0: [0, 1]} + + # DistSpec: + # shard_sequence: S01, R + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {0: [0]} + + # DistSpec: + # shard_sequence: S1, R + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_mix_gather_RS01(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + + (f, b) = (0, 1) + f_target_pair = (f, []) + b_target_pair = (b, [0, 1]) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_to_comm = tensor_to_check[:, rank:rank + 1].contiguous().cuda() + + dim_partition_dict = {1: [0, 1]} + # DistSpec: + # shard_sequence: R, S01 + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_RS01(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + rank_stride = tensor_width // 8 + tensor_to_comm = tensor_to_check[:, rank:rank + rank_stride].contiguous().cuda() + + dim_partition_dict = {1: [0, 1]} + + # DistSpec: + # shard_sequence: R, S01 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {1: [0]} + + # DistSpec: + # shard_sequence: R, S1 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_comm(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + physical_mesh_id = torch.arange(0, 8) + assert rank == gpc.get_global_rank() + + mesh_shape = (2, 4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True, need_flatten=True) + + check_mix_gather_S0S1(device_mesh, rank) + + check_two_all_gather_S0S1(device_mesh, rank) + + check_mix_gather_S1S0(device_mesh, rank) + + check_two_all_gather_S1S0(device_mesh, rank) + + check_mix_gather_S01R(device_mesh, rank) + + check_two_all_gather_S01R(device_mesh, rank) + + check_mix_gather_RS01(device_mesh, rank) + + check_two_all_gather_RS01(device_mesh, rank) + + +@pytest.mark.skip(reason="Skip because the check functions assume 8 GPUS but CI only have 4 GPUs") +@rerun_if_address_is_in_use() +def test_mix_gather(): + world_size = 8 + spawn(check_comm, world_size) + + +if __name__ == '__main__': + test_mix_gather() diff --git a/tests/test_tensor/test_parameter.py b/tests/test_tensor/test_parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3f05da1ffa8c1fc62851cadf7b572ef0a6d63d --- /dev/null +++ b/tests/test_tensor/test_parameter.py @@ -0,0 +1,34 @@ +import pytest +import torch +from common_utils import tensor_equal + +import colossalai +from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import free_port + + +@pytest.mark.skip +def test_multiinheritance(): + colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') + colo_param = ColoParameter(None, requires_grad=True) + assert colo_param.dist_spec.placement.value == 'r' + assert isinstance(colo_param, ColoTensor) + assert isinstance(colo_param, torch.nn.Parameter) + + # __deepcopy__ overload + import copy + colo_param2 = copy.deepcopy(colo_param) + assert isinstance(colo_param2, ColoParameter) + assert tensor_equal(colo_param.data, colo_param2.data) + assert colo_param.requires_grad == colo_param2.requires_grad + + # __repr__ overload + assert 'ColoParameter' in str(colo_param) + + # __torch_function__ + clone_param = torch.clone(colo_param) + assert isinstance(clone_param, ColoTensor) + + +if __name__ == '__main__': + test_multiinheritance() diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py new file mode 100644 index 0000000000000000000000000000000000000000..6fe9ee292cd014ae37d41ec5df4f7fcd6ccb2d99 --- /dev/null +++ b/tests/test_tensor/test_shape_consistency.py @@ -0,0 +1,144 @@ +from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern +import torch +from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec +from colossalai.device.device_mesh import DeviceMesh + +physical_mesh_id = torch.arange(0, 16).reshape(2, 8) +mesh_shape = (4, 4) +# [[0, 1, 2, 3], +# [4, 5, 6, 7], +# [8, 9, 10,11], +# [12,13,14,15]] +device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) +entire_shape = torch.Size((64, 32, 16)) +shape_consistency_manager = ShapeConsistencyManager() + + +def test_one_step_transform(): + + dim_partition_dict = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + + # {DistSpec: + # shard_sequence: R,S1,R + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec: + # shard_sequence: S0,R,R + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)} + rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, { + "forward": 0, + "backward": 0, + "total": 0 + }) + + assert '[R, S1, R]' in [ + str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys() + ] + assert '[S0, R, R]' in [ + str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys() + ] + + dim_partition_dict_all2all = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec_all2all = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_all2all) + # {DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 1), 0), DistSpec: + # shard_sequence: R,S1,S0 + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec: + # shard_sequence: S0,R,S1 + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)} + rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, { + "forward": 0, + "backward": 0, + "total": 0 + }) + + assert '[S01, R, R]' in [ + str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() + ] + assert '[R, S1, S0]' in [ + str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() + ] + assert '[S0, R, S1]' in [ + str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() + ] + + dim_partition_shard = {0: [0]} + # DistSpec: + # shard_sequence: S0,R,R + # device_mesh_shape: (4, 4) + sharding_spec_shard = ShardingSpec(device_mesh, entire_shape, dim_partition_shard) + # {DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1), 0), DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0), DistSpec: + # shard_sequence: S0,R,S1 + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)} + rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, { + "forward": 0, + "backward": 0, + "total": 0 + }) + + assert '[S01, R, R]' in [ + str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() + ] + assert '[S0, S1, R]' in [ + str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() + ] + assert '[S0, R, S1]' in [ + str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() + ] + + +def test_shape_consistency(): + dim_partition_source = {1: [0, 1]} + dim_partition_target = {0: [0, 1]} + + # DistSpec: + # shard_sequence: R,S01,R + # device_mesh_shape: (4, 4) + sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source) + + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) + + transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( + sharding_spec_source, sharding_spec_target) + + transform_path_str = '->'.join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path]) + assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]' + + # all-gather(S01) -> S0 + assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD + assert comm_action_sequence[0].gather_dim == 1 + assert comm_action_sequence[0].logical_process_axis == 1 + + # all-to-all(R, S0) -> [S0, R] + assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD + assert comm_action_sequence[1].gather_dim == 1 + assert comm_action_sequence[1].shard_dim == 0 + assert comm_action_sequence[1].logical_process_axis == 0 + + # shard(S0) -> [S01] + assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD + assert comm_action_sequence[2].shard_dim == 0 + assert comm_action_sequence[2].logical_process_axis == 1 + + assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]', + '[S01, R, R]')][0] == transform_path + assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]', + '[S01, R, R]')][1] == comm_action_sequence + + +if __name__ == '__main__': + test_one_step_transform() + test_shape_consistency() diff --git a/tests/test_tensor/test_shape_consistency_apply.py b/tests/test_tensor/test_shape_consistency_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..b57952df401fdcbe99299077e7c5618c137045e6 --- /dev/null +++ b/tests/test_tensor/test_shape_consistency_apply.py @@ -0,0 +1,76 @@ +import pytest +import torch + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_apply(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = torch.Size((4, 2)) + shape_consistency_manager = ShapeConsistencyManager() + dim_partition_source = {0: [0]} + dim_partition_target = {1: [0]} + + # DistSpec: + # shard_sequence: S0,R + # device_mesh_shape: (2, 2) + sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source) + + # DistSpec: + # shard_sequence: R,S0 + # device_mesh_shape: (2, 2) + sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) + + if rank in (0, 1): + sharded_tensor_0 = torch.zeros(2, 1) + sharded_tensor_1 = torch.ones(2, 1) + # tensor([[0., 1.], + # [0., 1.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + if rank in (2, 3): + sharded_tensor_0 = torch.ones(2, 1) * 2 + sharded_tensor_1 = torch.ones(2, 1) * 3 + # tensor([[2., 3.], + # [2., 3.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + + if rank in (0, 1): + # tensor([[0.], + # [0.], + # [2.], + # [2.]]) + tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda() + if rank in (2, 3): + # tensor([[1.], + # [1.], + # [3.], + # [3.]]) + tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda() + + tensor_to_comm.sharding_spec = sharding_spec_source + tensor_to_comm = shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target) + assert tensor_to_comm.equal(tensor_to_check) + assert str(tensor_to_comm.sharding_spec.sharding_sequence) == str(sharding_spec_target.sharding_sequence) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_apply(): + world_size = 4 + spawn(check_apply, world_size) + + +if __name__ == '__main__': + test_apply() diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..d66d4fec14d11142a851535114e277c32c97c3c4 --- /dev/null +++ b/tests/test_tensor/test_sharded_linear.py @@ -0,0 +1,232 @@ +import pytest +import torch +import torch.nn.functional as F + +import colossalai +from colossalai.device.device_mesh import DeviceMesh +from colossalai.nn._ops._utils import gather_forward_split_backward +from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # create mlp vars + x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda() + w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda() + b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda() + + # run normal forward + out = F.linear(x, w, b) + + # create mesh meta + # the mesh is in the following topo + # [[0, 1], + # [2, 3]] + physical_mesh_id = torch.arange(0, 4).reshape(2, 2) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + row_id = rank // 2 + column_id = rank % 2 + + # create pg + row_process_group = None + col_process_group = None + row_to_ranks = {0: [0, 1], 1: [2, 3]} + col_to_ranks = {0: [0, 2], 1: [1, 3]} + + for idx in range(2): + # row ranks + row_ranks = row_to_ranks[idx] + row_pg = ProcessGroup(ranks=row_ranks, tp_degree=2) + + # col ranks + col_ranks = col_to_ranks[idx] + col_pg = ProcessGroup(ranks=col_ranks, tp_degree=2) + + if rank in row_ranks: + row_process_group = row_pg + + if rank in col_ranks: + col_process_group = col_pg + + ######################## + # RRR x RS0 -> RRS0 # + ######################## + # w will be transposed in F.linear + x_replica = x.detach().clone() + w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[row_id] + b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[row_id] + + # adding sharding spec + x_replica.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={}) + w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [0]}) + b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [0]}) + + # check sharding spec + assert str(x_replica.sharding_spec.sharding_sequence) == "[R, R, R]" + assert str(w_shard.sharding_spec.sharding_sequence) == "[S0, R]" + assert str(b_shard.sharding_spec.sharding_sequence) == "[S0]" + + w_shard.pg_axis0 = col_process_group + w_shard.pg_axis1 = row_process_group + + out_shard = F.linear(x_replica, w_shard, b_shard) + assert str(out_shard.sharding_spec.sharding_sequence) == "[R, R, S0]" + + # each row only has a mini-batch + expected_out_shard = torch.chunk(out, chunks=2, dim=2)[row_id] + assert torch.allclose(out_shard, expected_out_shard) + + ######################## + # S0RR x RS1 -> S0RS1 # + ######################## + # w will be transposed in F.linear + x_shard = torch.chunk(x.detach().clone(), chunks=2, dim=0)[row_id] + w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[column_id] + b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[column_id] + + # adding sharding spec + x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0]}) + w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1]}) + b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]}) + + # check sharding spec + assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, R]" + assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, R]" + assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]" + + w_shard.pg_axis0 = col_process_group + w_shard.pg_axis1 = row_process_group + + out_shard = F.linear(x_shard, w_shard, b_shard) + + # each row only has a mini-batch + expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id] + expected_out_shard = torch.chunk(expected_out_shard, chunks=2, dim=2)[column_id] + assert torch.allclose(out_shard, expected_out_shard) + + ######################## + # S0RS1 x S1R -> S0RR # + ######################## + # w will be transposed in F.linear + x_shard = torch.chunk(x.clone(), chunks=2, dim=0)[row_id] + x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id] + w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id] + b_replica = b.clone() + + # adding sharding spec + x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0], 2: [1]}) + w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]}) + b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={}) + + # check sharding spec + assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, S1]" + assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]" + assert str(b_replica.sharding_spec.sharding_sequence) == "[R]" + + w_shard.pg_axis0 = col_process_group + w_shard.pg_axis1 = row_process_group + + out_shard = F.linear(x_shard, w_shard, b_replica) + + # each row only has a mini-batch + expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id] + assert torch.allclose(out_shard, expected_out_shard) + + ######################## + # RRS0 x S0R -> RRR # + ######################## + # w will be transposed in F.linear + x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id] + w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id] + b_replica = b.clone() + + # adding sharding spec + x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]}) + w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [0]}) + b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={}) + + # check sharding spec + assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]" + assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S0]" + assert str(b_replica.sharding_spec.sharding_sequence) == "[R]" + + w_shard.pg_axis0 = col_process_group + w_shard.pg_axis1 = row_process_group + + out_shard = F.linear(x_shard, w_shard, b_replica) + + # each row only has a mini-batch + expected_out_shard = out + assert torch.allclose(out_shard, expected_out_shard) + + ######################## + # RS0S1 x S1R -> RS0R # + ######################## + # w will be transposed in F.linear + x_shard = torch.chunk(x.clone(), chunks=2, dim=1)[row_id] + x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id] + w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id] + b_replica = b.clone() + + # adding sharding spec + x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={1: [0], 2: [1]}) + w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]}) + b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={}) + + # check sharding spec + assert str(x_shard.sharding_spec.sharding_sequence) == "[R, S0, S1]" + assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]" + assert str(b_replica.sharding_spec.sharding_sequence) == "[R]" + + w_shard.pg_axis0 = col_process_group + w_shard.pg_axis1 = row_process_group + + out_shard = F.linear(x_shard, w_shard, b_replica) + + # each row only has a mini-batch + expected_out_shard = torch.chunk(out, chunks=2, dim=1)[row_id] + assert torch.allclose(out_shard, expected_out_shard) + + ######################## + # RRS0 x S0S1 -> RRS1 # + ######################## + # w will be transposed in F.linear + x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id] + w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id] + w_shard = torch.chunk(w_shard, chunks=2, dim=0)[column_id] + b_shard = torch.chunk(b.clone(), chunks=2, dim=0)[column_id] + + # adding sharding spec + x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]}) + w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1], 1: [0]}) + b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]}) + + # check sharding spec + assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]" + assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, S0]" + assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]" + + w_shard.pg_axis0 = col_process_group + w_shard.pg_axis1 = row_process_group + + out_shard = F.linear(x_shard, w_shard, b_shard) + + # each row only has a mini-batch + expected_out_shard = torch.chunk(out, chunks=2, dim=2)[column_id] + assert torch.allclose(out_shard, expected_out_shard) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4]) +@rerun_if_address_is_in_use() +def test_sharded_mlp(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_sharded_mlp(4) diff --git a/tests/test_tensor/test_sharding_spec.py b/tests/test_tensor/test_sharding_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..909c84ef0f0ebeadcaab09c8fa23dd0a90199843 --- /dev/null +++ b/tests/test_tensor/test_sharding_spec.py @@ -0,0 +1,25 @@ +import torch + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec + + +def test_sharding_spec(): + physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + mesh_shape = (4, 4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7], + # [8, 9, 10,11], + # [12,13,14,15]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + entire_shape = torch.Size((16, 8, 6)) + dim_partition_dict = {0: [0, 1]} + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + assert str(sharding_spec.sharding_sequence) == "[S01, R, R]" + + +if __name__ == '__main__': + test_sharding_spec() diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py new file mode 100644 index 0000000000000000000000000000000000000000..c636d9442902ebd2a85931fb3eaa8b3cc71231f5 --- /dev/null +++ b/tests/test_tensor/test_tp_with_zero.py @@ -0,0 +1,143 @@ +import pytest +import torch +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP +from colossalai.zero.gemini import search_chunk_configuration +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed, tensor_shard_equal +from tests.test_tensor.model.test_gpt2 import init_megatron_spec + + +def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup): + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + # key is 'module.model.PARAMETER', so we truncate it + key = key[7:] + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) + assert tensor_shard_equal(value, temp_zero_value, pg.tp_local_rank(), pg.tp_world_size()), \ + "parameter '{}' has problem.".format(key) + + +def run_fwd_bwd(model, criterion, optimizer, input_ids): + optimizer.zero_grad() + logits = model(input_ids) + logits = logits.float() + loss = criterion(logits, input_ids) + optimizer.backward(loss) + return logits + + +def init_1d_row_spec(model, pg: ProcessGroup): + spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + for n, p in model.named_parameters(): + p.set_process_group(pg) + if 'weight' in n and 'ln' not in n: + p.set_tensor_spec(*spec) + + +def init_1d_col_spec(model, pg: ProcessGroup): + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + for n, p in model.named_parameters(): + p.set_process_group(pg) + if 'ln' not in n and ('weight' in n or 'bias' in n): + p.set_tensor_spec(*spec) + + +@parameterize('placement_policy', ['cuda', 'cpu']) +def run_gpt(placement_policy, tp_init_spec_func=None): + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + model = model.cuda() + torch_model = model_builder().cuda() + + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + + # world size, dp = 2, tp =2, construct a hybrid parallelism. + if world_size == 4: + pg = ProcessGroup(tp_degree=2) + else: + pg = ProcessGroup(tp_degree=world_size) + + if tp_init_spec_func: + tp_init_spec_func(model, pg) + + dp_world_size = pg.dp_world_size() + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[dp_world_size]['chunk_size'] = 5000 + config_dict[dp_world_size]['keep_gathered'] = False + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + + model = GeminiDDP(model, init_device, placement_policy, True, False) + # The same as the following 3 lines + # chunk_manager = ChunkManager(config_dict, init_device=init_device) + # gemini_manager = GeminiManager(placement_policy, chunk_manager) + # model = ZeroDDP(model, gemini_manager, pin_memory=True) + + zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1) + # The same as the following 2 lines + # optimizer = HybridAdam(model.parameters(), lr=1e-3) + # zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1) + + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) + + check_param(model, torch_model, pg) + + model.eval() + torch_model.eval() + + set_seed(pg.dp_local_rank()) + for i, (input_ids, label) in enumerate(train_dataloader): + if i > 2: + break + input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg)) + zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo) + torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids) + assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2) + + zero_optim.step() + torch_optim.step() + check_param(model, torch_model, pg) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + if world_size == 4: + run_gpt(tp_init_spec_func=init_megatron_spec) + else: + run_gpt(tp_init_spec_func=init_1d_col_spec) + run_gpt(tp_init_spec_func=init_1d_row_spec) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_gpt(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_gpt(4) diff --git a/tests/test_trainer/test_pipeline/test_p2p.py b/tests/test_trainer/test_pipeline/test_p2p.py new file mode 100644 index 0000000000000000000000000000000000000000..cb7a193d2bfa3f0ebcb1c8f58ff097161503993f --- /dev/null +++ b/tests/test_trainer/test_pipeline/test_p2p.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pytest +import torch +import torch.distributed as dist + +from colossalai.communication import ( + recv_backward, + recv_forward, + recv_obj_meta, + send_backward, + send_backward_recv_forward, + send_forward, + send_forward_recv_backward, + send_obj_meta, +) +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import get_dist_logger +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + +BATCH_SIZE = 4 +SEQ_LENGTH = 2 +HIDDEN_SIZE = 16 + +CONFIG = dict(parallel=dict(pipeline=dict(size=4), tensor=dict(size=1, mode=None)), seed=1024) + + +def check_equal(A, B): + return torch.allclose(A, B, rtol=1e-5, atol=1e-3) + + +def check_forward(output_tensor, rank, logger): + dist.barrier() + if gpc.is_first_rank(ParallelMode.PIPELINE): + tensor = output_tensor.clone() + else: + tensor = recv_forward(output_tensor.shape) + logger.info('Rank {} received forward. Correct tensor: {}'.format(rank, check_equal(tensor, output_tensor))) + if not gpc.is_last_rank(ParallelMode.PIPELINE): + send_forward(tensor) + logger.info('Rank {} sent forward.'.format(rank)) + + +def check_backward(output_grad, rank, logger): + dist.barrier() + if gpc.is_last_rank(ParallelMode.PIPELINE): + grad = output_grad.clone() + else: + grad = recv_backward(output_grad.shape) + logger.info('Rank {} received backward. Correct grad: {}'.format(rank, check_equal(grad, output_grad))) + if not gpc.is_first_rank(ParallelMode.PIPELINE): + send_backward(grad) + logger.info('Rank {} sent backward.'.format(rank)) + + +def check_forward_backward(output_tensor, output_grad, rank, logger): + dist.barrier() + if not gpc.is_first_rank(ParallelMode.PIPELINE): + tensor = send_backward_recv_forward(output_grad, output_tensor.shape) + logger.info('Rank {} sent backward received forward. Correct tensor: {}'.format( + rank, check_equal(tensor, output_tensor))) + if not gpc.is_last_rank(ParallelMode.PIPELINE): + grad = send_forward_recv_backward(output_tensor, output_grad.shape) + logger.info('Rank {} sent forward received backward. Correct grad: {}'.format( + rank, check_equal(grad, output_grad))) + + +def check_comm(size, rank, prev_rank, next_rank, logger): + dtype = torch.float32 + device = get_current_device() + tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + tensor = torch.randn(tensor_shape, dtype=dtype, device=device) + dist.all_reduce(tensor) + grad = torch.randn(grad_shape, dtype=dtype, device=device) + dist.all_reduce(grad) + check_forward(tensor, rank, logger) + check_backward(grad, rank, logger) + check_forward_backward(tensor, grad, rank, logger) + + +def run_check(rank, world_size, port): + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + logger = get_dist_logger() + rank = gpc.get_global_rank() + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + logger.info('Rank {0}: prev rank {1}, next rank {2}'.format(rank, prev_rank, next_rank)) + logger.info('Distributed environment is initialzied.') + + check_comm(world_size, rank, prev_rank, next_rank, logger) + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_p2p(): + world_size = 4 + spawn(run_check, world_size) + + +if __name__ == '__main__': + test_p2p() diff --git a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7bf6b3d89f54b6704096cb0394c56db51aee5f --- /dev/null +++ b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py @@ -0,0 +1,87 @@ +# referenced from Megatron and used to testify communication + +import os +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet18 + +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_dataloader, print_rank_0 + +BATCH_SIZE = 8 + +CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=dict(size=2), tensor=dict(size=1, mode=None))) + + +def run_schedule(rank, world_size, port): + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # build model + model = resnet18(num_classes=10) + + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2) + elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1: + + class Flatten(nn.Module): + + def forward(self, x): + return torch.flatten(x, 1) + + model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc) + + print_rank_0('model is created') + + train_dataset = CIFAR10(root=Path(os.environ['DATA']), + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), + ])) + + train_dataloader = get_dataloader( + dataset=train_dataset, + shuffle=True, + add_sampler=True, + batch_size=BATCH_SIZE, + pin_memory=True, + ) + + # build criterion + criterion = torch.nn.CrossEntropyLoss() + + # optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0) + + # initialize + engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader) + + # build pipeline schedule + schedule = engine.schedule + + # run schedule + data_iter = iter(train_dataloader) + schedule.forward_backward_step(engine, data_iter) + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pipeline_schedule(): + world_size = 2 + spawn(run_schedule, world_size) + + +if __name__ == '__main__': + test_pipeline_schedule() diff --git a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..753f82222f9d81679f4a720d258abc1659bc8fc3 --- /dev/null +++ b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py @@ -0,0 +1,59 @@ +import pytest +import torch + +import colossalai +from colossalai.amp.amp_type import AMP_TYPE +from colossalai.logging import get_dist_logger +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.trainer import Trainer +from colossalai.utils import MultiTimer +from tests.components_to_test.registry import non_distributed_component_funcs + +BATCH_SIZE = 4 +IMG_SIZE = 32 +NUM_EPOCHS = 200 + +CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH)) + + +@parameterize('model_name', ['repeated_computed_layers', 'resnet18', 'nested_model']) +def run_trainer(model_name): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model = model_builder() + optimizer = optimizer_class(model.parameters(), lr=1e-3) + engine, train_dataloader, *_ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + + logger = get_dist_logger() + logger.info("engine is built", ranks=[0]) + + timer = MultiTimer() + trainer = Trainer(engine=engine, logger=logger, timer=timer) + logger.info("trainer is built", ranks=[0]) + + logger.info("start training", ranks=[0]) + trainer.fit(train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + epochs=NUM_EPOCHS, + max_steps=3, + display_progress=True, + test_interval=5) + torch.cuda.empty_cache() + + +def run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_trainer_no_pipeline(): + world_size = 4 + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_trainer_no_pipeline() diff --git a/tests/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_trainer/test_trainer_with_pipe_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..bb63d51a0b656183c20526bf0389727219baf2b1 --- /dev/null +++ b/tests/test_trainer/test_trainer_with_pipe_schedule.py @@ -0,0 +1,96 @@ +import os +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +from torch.optim import Adam +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet18 + +import colossalai +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.trainer import Trainer +from colossalai.utils import MultiTimer, get_dataloader + +BATCH_SIZE = 4 +IMG_SIZE = 32 +NUM_EPOCHS = 200 + +CONFIG = dict( + NUM_MICRO_BATCHES=2, + parallel=dict(pipeline=2), +) + + +def run_trainer_with_pipeline(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # build model + model = resnet18(num_classes=10) + + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2) + elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1: + + class Flatten(nn.Module): + + def forward(self, x): + return torch.flatten(x, 1) + + model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc) + + # build dataloaders + train_dataset = CIFAR10(root=Path(os.environ['DATA']), + download=True, + transform=transforms.Compose([ + transforms.Resize(size=(IMG_SIZE, IMG_SIZE)), + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ])) + + train_dataloader = get_dataloader(dataset=train_dataset, + shuffle=True, + batch_size=BATCH_SIZE, + pin_memory=True, + drop_last=True) + + # build optimizer + optimizer = Adam(model.parameters(), lr=0.001) + criterion = nn.CrossEntropyLoss() + + engine, train_dataloader, *args = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + + logger = get_dist_logger() + logger.info("engine is built", ranks=[0]) + timer = MultiTimer() + trainer = Trainer(engine=engine, logger=logger, timer=timer) + logger.info("trainer is built", ranks=[0]) + + logger.info("start training", ranks=[0]) + + trainer.fit(train_dataloader=train_dataloader, + epochs=NUM_EPOCHS, + max_steps=3, + display_progress=True, + test_interval=5) + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_trainer_with_pipeline(): + world_size = 4 + spawn(run_trainer_with_pipeline, world_size) + + +if __name__ == '__main__': + test_trainer_with_pipeline() diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py new file mode 100644 index 0000000000000000000000000000000000000000..59a8acd4b21022fae583e6737e75f47fb27c651e --- /dev/null +++ b/tests/test_utils/test_activation_checkpointing.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pytest +import torch +import torch.nn.functional as F + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.context.random import add_seed, reset_seeds, seed, set_mode +from colossalai.testing import clear_cache_before_run, parameterize +from colossalai.utils.activation_checkpoint import checkpoint + + +def forward(x, weight): + out = torch.matmul(x, weight) + with seed(ParallelMode.DATA): + out_ = F.dropout(out, p=0.4, training=True) + return out_ + + +def forward_inplace_ckpt(x, weight, cpu_offload=False): + out = torch.matmul(x, weight) + bn = torch.nn.BatchNorm1d(4, affine=False) + bn = bn.to(device="cuda") + out = bn(out) + + def ckpt0(x): + return F.relu(x, inplace=True) + + out = checkpoint(ckpt0, cpu_offload, out, use_reentrant=False) + return out + + +def forward_inplace(x, weight): + out = torch.matmul(x, weight) + bn = torch.nn.BatchNorm1d(4, affine=False) + bn = bn.to(device="cuda") + out = bn(out) + out = F.relu(out, inplace=True) + return out + + +@pytest.mark.gpu +@clear_cache_before_run() +@parameterize("use_reentrant", [True, False]) +@parameterize("cpu_offload", [True, False]) +def test_activation_checkpointing(cpu_offload, use_reentrant): + + # as seed manager is singleton + # if we don't reset seeds here, + # other tests might affect this test + reset_seeds() + + # We put initilization here to avoid change cuda rng state below + inputs = torch.rand(2, 2, requires_grad=True, device='cuda') + weight = torch.rand(2, 4, requires_grad=True, device='cuda') + + # Get a copy of input tensors + inputs_ = torch.empty(2, 2, requires_grad=True, device='cuda') + inputs_.data.copy_(inputs.data) + weight_ = torch.empty(2, 4, requires_grad=True, device='cuda') + weight_.data.copy_(weight.data) + + add_seed(ParallelMode.GLOBAL, 1024) + add_seed(ParallelMode.DATA, 1026) + set_mode(ParallelMode.GLOBAL) + global_cuda_rng_state = torch.cuda.get_rng_state() + set_mode(ParallelMode.DATA) + data_parallel_cuda_rng_state = torch.cuda.get_rng_state() + set_mode(ParallelMode.GLOBAL) + + out = forward(inputs, weight) + loss = out.sum() + loss.backward() + + # Recover cuda rng states + set_mode(ParallelMode.GLOBAL) + torch.cuda.set_rng_state(global_cuda_rng_state) + set_mode(ParallelMode.DATA) + torch.cuda.set_rng_state(data_parallel_cuda_rng_state) + set_mode(ParallelMode.GLOBAL) + + out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=use_reentrant) + loss = out.sum() + loss.backward() + + assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' + torch.cuda.empty_cache() + + # Extra test for use_reentrant=False + if use_reentrant == False: + # Recover cuda rng states + set_mode(ParallelMode.GLOBAL) + torch.cuda.set_rng_state(global_cuda_rng_state) + set_mode(ParallelMode.DATA) + torch.cuda.set_rng_state(data_parallel_cuda_rng_state) + set_mode(ParallelMode.GLOBAL) + + out = forward_inplace(inputs, weight) + loss = out.sum() + loss.backward() + + # Recover cuda rng states + set_mode(ParallelMode.GLOBAL) + torch.cuda.set_rng_state(global_cuda_rng_state) + set_mode(ParallelMode.DATA) + torch.cuda.set_rng_state(data_parallel_cuda_rng_state) + set_mode(ParallelMode.GLOBAL) + + out = forward_inplace_ckpt(inputs_, weight_, cpu_offload=cpu_offload) + loss = out.sum() + loss.backward() + + assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' + torch.cuda.empty_cache() + + # as seed manager is singleton + # if we don't reset seeds here, + # other tests will fail if running together with this test + # as other tests can't overwrite the seed set by this test + reset_seeds() + + +if __name__ == "__main__": + test_activation_checkpointing(False, False) diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..335be61359ed2f802d6f1544e232dd9fd3c36dd7 --- /dev/null +++ b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint + +import pytest +import torch +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.utils import is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_1d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_checkpoint_1d(): + spawn(check_checkpoint_1d, 8) + + +if __name__ == "__main__": + test_checkpoint_1d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..175d9ef6ceb9170e6bf0687e0bb8e88ffdfcb5f8 --- /dev/null +++ b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint + +import pytest +import torch +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.utils import is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_2d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_checkpoint_2d(): + spawn(check_checkpoint_2d, 8) + + +if __name__ == "__main__": + test_checkpoint_2d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..33cb3a65d184f5ae14627010bc91de9d967c12df --- /dev/null +++ b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint + +import pytest +import torch +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.utils import is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_2p5d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_checkpoint_2p5d(): + spawn(check_checkpoint_2p5d, 8) + + +if __name__ == "__main__": + test_checkpoint_2p5d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..73ac2dd5fe1837ae9ee4f1a7b95e43fc6c51b482 --- /dev/null +++ b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint + +import pytest +import torch +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.utils import is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_3d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_checkpoint_3d(): + spawn(check_checkpoint_3d, 8) + + +if __name__ == "__main__": + test_checkpoint_3d() diff --git a/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py b/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..6d89fb90c574e9b06de571760b59441e75bf2b33 --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +from colossalai.utils.checkpoint_io.meta import ParamDistMeta +from colossalai.utils.checkpoint_io.utils import build_checkpoints +from torch.optim import Adam + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def test_global_model(): + model = DummyModel() + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model) + assert len(model_checkpoints) == 1 + assert len(optimizer_checkpoints) == 0 + assert meta['dist_meta'] is None + orig_state_dict = model.state_dict() + global_state_dict = model_checkpoints[0] + assert set(orig_state_dict.keys()) == set(global_state_dict.keys()) + for k, v in orig_state_dict.items(): + assert torch.equal(v, global_state_dict[k]) + + +def test_global_model_shard(): + model = DummyModel() + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model) + assert len(model_checkpoints) == 2 + assert len(optimizer_checkpoints) == 0 + assert meta['dist_meta'] is None + orig_state_dict = model.state_dict() + assert set(orig_state_dict.keys()) == set(model_checkpoints[0].keys()) | set(model_checkpoints[1].keys()) + assert len(set(model_checkpoints[0].keys()) & set(model_checkpoints[1].keys())) == 0 + for k, v in orig_state_dict.items(): + for state_dict in model_checkpoints: + if k in state_dict: + assert torch.equal(v, state_dict[k]) + + +def test_global_optimizer(): + model = DummyModel() + for p in model.parameters(): + p.grad = torch.rand_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer) + assert len(optimizer_checkpoints) == 1 + assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1} + for state in meta['paired_os'].values(): + for k, is_paired in state.items(): + if k == 'step': + assert not is_paired + else: + assert is_paired + orig_state_dict = optimizer.state_dict() + state_dict = optimizer_checkpoints[0] + for k, orig_state in orig_state_dict['state'].items(): + state = state_dict['state'][k] + for v1, v2 in zip(orig_state.values(), state.values()): + if isinstance(v2, torch.Tensor): + assert torch.equal(v1, v2) + else: + assert v2 == v2 + assert orig_state_dict['param_groups'] == state_dict['param_groups'] + + +def test_global_optimizer_shard(): + model = DummyModel() + for p in model.parameters(): + p.grad = torch.rand_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model, optimizer) + assert len(optimizer_checkpoints) == 2 + assert 'param_groups' in optimizer_checkpoints[0] and 'param_groups' not in optimizer_checkpoints[1] + orig_state_dict = optimizer.state_dict() + assert set(orig_state_dict['state'].keys()) == set(optimizer_checkpoints[0]['state'].keys()) | set( + optimizer_checkpoints[1]['state'].keys()) + assert len(set(optimizer_checkpoints[0]['state'].keys()) & set(optimizer_checkpoints[1]['state'].keys())) == 0 + for k, orig_state in orig_state_dict['state'].items(): + state = optimizer_checkpoints[0]['state'][k] if k in optimizer_checkpoints[0][ + 'state'] else optimizer_checkpoints[1]['state'][k] + for v1, v2 in zip(orig_state.values(), state.values()): + if isinstance(v2, torch.Tensor): + assert torch.equal(v1, v2) + else: + assert v1 == v2 + + assert orig_state_dict['param_groups'] == optimizer_checkpoints[0]['param_groups'] + + +def test_dist_model_optimizer(): + model = DummyModel() + for p in model.parameters(): + p.grad = torch.rand_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + dist_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)} + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta) + assert dist_meta == meta['dist_meta'] + assert len(model_checkpoints) == 1 + assert len(optimizer_checkpoints) == 1 + assert 'fc.weight' in model_checkpoints[0] and 'fc.bias' in model_checkpoints[0] + assert 0 in optimizer_checkpoints[0]['state'] and 1 in optimizer_checkpoints[0]['state'] + dist_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)} + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta) + assert dist_meta == meta['dist_meta'] + assert len(model_checkpoints) == 1 + assert len(optimizer_checkpoints) == 1 + + +if __name__ == '__main__': + test_global_model() + test_global_model_shard() + test_global_optimizer() + test_global_optimizer_shard() + test_dist_model_optimizer() diff --git a/tests/test_utils/test_checkpoint_io/test_load.py b/tests/test_utils/test_checkpoint_io/test_load.py new file mode 100644 index 0000000000000000000000000000000000000000..b1a741515728c2e6f697bdbd8f7e017924388b90 --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_load.py @@ -0,0 +1,186 @@ +from copy import deepcopy +from functools import partial +from tempfile import TemporaryDirectory +from typing import Dict + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from torch import Tensor +from torch.nn import Module +from torch.optim import Adam, Optimizer + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.checkpoint_io.io import load, save +from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta + + +def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: + assert set(a.keys()) == set(b.keys()) + for k, v in a.items(): + assert torch.equal(v, b[k]) + + +def check_optim_state_dict(a: dict, b: dict, ignore_param_gruops: bool = False) -> None: + assert set(a['state'].keys()) == set(b['state'].keys()) + for k, state in a['state'].items(): + b_state = b['state'][k] + for v1, v2 in zip(state.values(), b_state.values()): + if isinstance(v1, Tensor): + assert torch.equal(v1, v2) + else: + assert v1 == v2 + if not ignore_param_gruops: + assert a['param_groups'] == b['param_groups'] + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def prepare_model_optim(shard: bool = False, zero: bool = False): + model = DummyModel() + if shard: + model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] + if zero: + dp_rank = dist.get_rank() // 2 + model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] + if dp_rank != 0: + model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) + for p in model.parameters(): + p.grad = torch.rand_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + return model, optimizer + + +def reset_model_optim(model: Module, optimizer: Optimizer, scalar: float = 0.0): + with torch.no_grad(): + for p in model.parameters(): + p.fill_(scalar) + for state in optimizer.state.values(): + for v in state.values(): + if isinstance(v, Tensor): + v.fill_(scalar) + + +def get_dist_metas(nprocs: int, zero: bool = False): + dp_world_size = nprocs // 2 + dist_metas = [] + for rank in range(nprocs): + if zero: + dist_metas.append({ + 'fc.weight': + ParamDistMeta(rank // 2, + dp_world_size, + rank % 2, + 2, + tp_shard_dims=[1], + tp_num_parts=[2], + zero_numel=10, + zero_orig_shape=[1, 10]), + 'fc.bias': + ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) + }) + else: + dist_metas.append({ + 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) + }) + return dist_metas + + +def get_redist_meta(nprocs: int): + dp_world_size = nprocs // 2 + rank_meta = { + 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)}, + 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)} + } + param_meta = { + 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamRedistMeta(dp_world_size, 1) + } + return RedistMeta(rank_meta, [], param_meta) + + +@pytest.mark.parametrize('max_shard_size_gb', [80 / 1024**3, 0]) +def test_save_global_load_global(max_shard_size_gb: float): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer, max_shard_size_gb=max_shard_size_gb) + new_model, new_optimizer = prepare_model_optim() + load(dir_name, new_model, new_optimizer, max_shard_size_gb=max_shard_size_gb) + check_model_state_dict(model.state_dict(), new_model.state_dict()) + check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict()) + + +def run_dist(rank, world_size, port, test_fn): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + test_fn() + + +def launch_dist(fn, world_size: int): + spawn(run_dist, world_size, test_fn=fn) + + +def save_dist(dir_name: str, zero: bool): + model, optmizer = prepare_model_optim(shard=True, zero=zero) + reset_model_optim(model, optmizer) + world_size = dist.get_world_size() + rank = dist.get_rank() + save(dir_name, model, optmizer, dist_meta=get_dist_metas(world_size, zero)[rank]) + + +def load_and_check_dist(dir_name: str): + world_size = dist.get_world_size() + model, optmizer = prepare_model_optim(shard=True) + reset_model_optim(model, optmizer) + model_state_dict = deepcopy(model.state_dict()) + optimizer_state_dict = deepcopy(optmizer.state_dict()) + reset_model_optim(model, optmizer, 1) + load(dir_name, model, optmizer, get_redist_meta(world_size), get_dist_metas(world_size)) + check_model_state_dict(model_state_dict, model.state_dict()) + check_optim_state_dict(optimizer_state_dict, optmizer.state_dict()) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_save_global_load_dist(): + model, optimizer = prepare_model_optim() + reset_model_optim(model, optimizer) + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer) + fn = partial(load_and_check_dist, dir_name) + launch_dist(fn, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_save_dist_load_dist(): + with TemporaryDirectory() as dir_name: + # save tp + dp + fn = partial(save_dist, dir_name, False) + launch_dist(fn, 2) + # load tp + dp + fn = partial(load_and_check_dist, dir_name) + launch_dist(fn, 2) + with TemporaryDirectory() as dir_name: + # save tp + zero + fn = partial(save_dist, dir_name, True) + launch_dist(fn, 4) + # load tp + dp + fn = partial(load_and_check_dist, dir_name) + launch_dist(fn, 2) + launch_dist(fn, 4) + + +if __name__ == '__main__': + test_save_global_load_global(80 / 1024**3) + test_save_global_load_global(0) + test_save_global_load_dist() + test_save_dist_load_dist() diff --git a/tests/test_utils/test_checkpoint_io/test_merge.py b/tests/test_utils/test_checkpoint_io/test_merge.py new file mode 100644 index 0000000000000000000000000000000000000000..255c74adf0a2c9c1111a60c04930a100f187c227 --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_merge.py @@ -0,0 +1,126 @@ +import os +from functools import partial +from tempfile import TemporaryDirectory + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.optim import Adam + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME +from colossalai.utils.checkpoint_io.io import merge, save +from colossalai.utils.checkpoint_io.meta import ParamDistMeta + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def prepare_model_optim(shard: bool = False, zero: bool = False): + model = DummyModel() + if shard: + model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] + if zero: + dp_rank = dist.get_rank() // 2 + model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] + if dp_rank != 0: + model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) + for p in model.parameters(): + p.grad = torch.ones_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + return model, optimizer + + +def test_merge_global(): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer) + with TemporaryDirectory() as output_dir: + merge(dir_name, output_dir) + assert len(os.listdir(output_dir)) == 0 + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3) + with TemporaryDirectory() as output_dir: + merge(dir_name, output_dir) + assert len(os.listdir(output_dir)) == 0 + + +def run_dist(rank, world_size, port, test_fn): + colossalai.launch(config={'parallel': { + 'tensor': { + 'mode': '1d', + 'size': 2 + } + }}, + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') + test_fn() + + +def run_save_dist(dir_name: str, zero: bool): + model, optmizer = prepare_model_optim(shard=True, zero=zero) + rank = dist.get_rank() + dp_world_size = dist.get_world_size() // 2 + if not zero: + dist_metas = { + 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) + } + else: + dist_metas = { + 'fc.weight': + ParamDistMeta(rank // 2, + dp_world_size, + rank % 2, + 2, + tp_shard_dims=[1], + tp_num_parts=[2], + zero_numel=10, + zero_orig_shape=[1, 10]), + 'fc.bias': + ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) + } + save(dir_name, model, optmizer, dist_meta=dist_metas) + + +@pytest.mark.dist +@pytest.mark.parametrize("zero", [False, True]) +@rerun_if_address_is_in_use() +def test_merge_tp_dp(zero: bool): + with TemporaryDirectory() as dir_name: + fn = partial(run_save_dist, dir_name, zero) + world_size = 4 + spawn(run_dist, world_size, test_fn=fn) + with TemporaryDirectory() as output_dir: + merge(dir_name, output_dir) + assert len(os.listdir(output_dir)) == 5 + global_meta = torch.load(os.path.join(output_dir, GLOBAL_META_FILE_NAME)) + assert len(global_meta['meta']) == 1 + meta = torch.load(os.path.join(output_dir, global_meta['meta'][0])) + assert meta['dist_meta'] is None + assert len(meta['params']) == 2 + assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 + model_state_dict = torch.load(os.path.join(output_dir, meta['model'][0])) + assert len(model_state_dict) == 2 + assert model_state_dict['fc.weight'].size(1) == 20 + optimizer_state_dict = torch.load(os.path.join(output_dir, meta['optimizer'][0])) + assert len(optimizer_state_dict['state']) == 2 + assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict + assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 20 + assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 20 + + +if __name__ == '__main__': + test_merge_global() + test_merge_tp_dp(False) + test_merge_tp_dp(True) diff --git a/tests/test_utils/test_checkpoint_io/test_merge_param.py b/tests/test_utils/test_checkpoint_io/test_merge_param.py new file mode 100644 index 0000000000000000000000000000000000000000..5da2ae4fe1f8f0e0797d2b733b4fe4ecf1778a67 --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_merge_param.py @@ -0,0 +1,101 @@ +import torch +from colossalai.utils.checkpoint_io.meta import ParamDistMeta +from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param + + +def test_unflatten_zero_param_even() -> None: + dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(4)] + orig_tensor = torch.rand(4, 4) + tensors = list(orig_tensor.reshape(-1).chunk(4)) + unflattened_tensor = unflatten_zero_param(tensors, dist_metas) + assert torch.equal(orig_tensor, unflattened_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_unflatten_zero_param_uneven() -> None: + dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(1, 3)] + orig_tensor = torch.rand(4, 4) + tensors = list(orig_tensor.reshape(-1).split([13, 3])) + unflattened_tensor = unflatten_zero_param(tensors, dist_metas) + assert torch.equal(orig_tensor, unflattened_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_gather_tp_param_1d_row() -> None: + dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[0], tp_num_parts=[4]) for i in range(4)] + orig_tensor = torch.rand(4, 4) + tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)] + gathered_tensor = gather_tp_param(tensors, dist_metas) + assert torch.equal(orig_tensor, gathered_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_gather_tp_param_1d_col() -> None: + dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[1], tp_num_parts=[4]) for i in range(4)] + orig_tensor = torch.rand(4, 4) + tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)] + gathered_tensor = gather_tp_param(tensors, dist_metas) + assert torch.equal(orig_tensor, gathered_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_gather_tp_param_2d() -> None: + dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for i in range(6)] + orig_tensor = torch.rand(4, 6) + tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] + gathered_tensor = gather_tp_param(tensors, dist_metas) + assert torch.equal(orig_tensor, gathered_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_gather_tp_param_2d_reverse() -> None: + dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) for i in range(6)] + orig_tensor = torch.rand(4, 6) + tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] + gathered_tensor = gather_tp_param(tensors, dist_metas) + assert torch.equal(orig_tensor, gathered_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_merge_param_hybrid() -> None: + dist_metas = [ + ParamDistMeta(i % 2, + 2, + i // 2, + 6, + tp_shard_dims=[1, 0], + tp_num_parts=[3, 2], + zero_numel=4, + zero_orig_shape=[2, 2]) for i in range(12) + ] + orig_tensor = torch.rand(4, 6) + tensors = [ + chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1) + for chunk in t.contiguous().reshape(-1).split([1, 3]) + ] + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_merge_param_dummy() -> None: + dist_metas = [ParamDistMeta(0, 1, 0, 1)] + orig_tensor = torch.rand(4, 6) + merged_tensor = merge_param([orig_tensor], dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +if __name__ == '__main__': + test_unflatten_zero_param_even() + test_unflatten_zero_param_uneven() + test_gather_tp_param_1d_row() + test_gather_tp_param_1d_col() + test_gather_tp_param_2d() + test_gather_tp_param_2d_reverse() + test_merge_param_hybrid() + test_merge_param_dummy() diff --git a/tests/test_utils/test_checkpoint_io/test_redist.py b/tests/test_utils/test_checkpoint_io/test_redist.py new file mode 100644 index 0000000000000000000000000000000000000000..144715bdfcca3db119eb96deb16091b10c537598 --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_redist.py @@ -0,0 +1,152 @@ +import os +from functools import partial +from tempfile import TemporaryDirectory + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.optim import Adam + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME +from colossalai.utils.checkpoint_io.io import redist, save +from colossalai.utils.checkpoint_io.meta import ( + ParamDistMeta, + ParamRedistMeta, + PipelineRedistMeta, + RankRedistMeta, + RedistMeta, +) + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def prepare_model_optim(shard: bool = False, zero: bool = False): + model = DummyModel() + if shard: + model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] + if zero: + dp_rank = dist.get_rank() // 2 + model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] + if dp_rank != 0: + model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) + for p in model.parameters(): + p.grad = torch.ones_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + return model, optimizer + + +def get_dist_metas(nprocs: int, zero: bool = False): + dp_world_size = nprocs // 2 + dist_metas = [] + for rank in range(nprocs): + if zero: + dist_metas.append({ + 'fc.weight': + ParamDistMeta(rank // 2, + dp_world_size, + rank % 2, + 2, + tp_shard_dims=[1], + tp_num_parts=[2], + zero_numel=10, + zero_orig_shape=[1, 10]), + 'fc.bias': + ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) + }) + else: + dist_metas.append({ + 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) + }) + return dist_metas + + +def get_redist_meta(nprocs: int): + dp_world_size = nprocs // 2 + rank_meta = { + 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)}, + 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)} + } + param_meta = { + 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamRedistMeta(dp_world_size, 1) + } + return RedistMeta(rank_meta, [], param_meta) + + +def check_checkpoint_shape(dir_name: str): + global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) + for meta_name in global_meta['meta']: + meta = torch.load(os.path.join(dir_name, meta_name)) + assert meta['dist_meta'] is not None + assert len(meta['params']) == 2 + assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 + model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) + assert len(model_state_dict) == 2 + assert model_state_dict['fc.weight'].size(1) == 10 + optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) + assert len(optimizer_state_dict['state']) == 2 + assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict + assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 10 + assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 10 + + +def test_global_to_dist(): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer) + with TemporaryDirectory() as output_dir: + redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4)) + check_checkpoint_shape(output_dir) + + +def run_dist(rank, world_size, port, test_fn): + colossalai.launch(config={'parallel': { + 'tensor': { + 'mode': '1d', + 'size': 2 + } + }}, + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') + test_fn() + + +def run_save_dist(dir_name: str, zero: bool): + model, optmizer = prepare_model_optim(shard=True, zero=zero) + rank = dist.get_rank() + save(dir_name, model, optmizer, dist_meta=get_dist_metas(4, zero)[rank]) + + +@pytest.mark.dist +@pytest.mark.parametrize("zero", [False, True]) +@rerun_if_address_is_in_use() +def test_dist_to_dist(zero: bool): + with TemporaryDirectory() as dir_name: + fn = partial(run_save_dist, dir_name, zero) + world_size = 4 + spawn(run_dist, world_size, test_fn=fn) + with TemporaryDirectory() as output_dir: + redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4)) + if not zero: + assert len(os.listdir(output_dir)) == 0 + else: + check_checkpoint_shape(output_dir) + + +if __name__ == '__main__': + test_global_to_dist() + test_dist_to_dist(False) + test_dist_to_dist(True) diff --git a/tests/test_utils/test_checkpoint_io/test_save.py b/tests/test_utils/test_checkpoint_io/test_save.py new file mode 100644 index 0000000000000000000000000000000000000000..e35e566f6ff83c91ce5f8a367e478e3f7eecc579 --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_save.py @@ -0,0 +1,149 @@ +import os +from functools import partial +from tempfile import TemporaryDirectory +from typing import Dict + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from torch import Tensor +from torch.optim import Adam + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.checkpoint_io.constant import ( + GLOBAL_META_FILE_NAME, + META_CKPT_FILE_NAME, + MODEL_CKPT_FILE_NAME, + OTHER_CKPT_FILE_NAME, +) +from colossalai.utils.checkpoint_io.io import save +from colossalai.utils.checkpoint_io.meta import ParamDistMeta + + +def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: + assert set(a.keys()) == set(b.keys()) + for k, v in a.items(): + assert torch.equal(v, b[k]) + + +def check_optim_state_dict(a: dict, b: dict, ignore_param_gruops: bool = False) -> None: + assert set(a['state'].keys()) == set(b['state'].keys()) + for k, state in a['state'].items(): + b_state = b['state'][k] + for v1, v2 in zip(state.values(), b_state.values()): + if isinstance(v1, Tensor): + assert torch.equal(v1, v2) + else: + assert v1 == v2 + if not ignore_param_gruops: + assert a['param_groups'] == b['param_groups'] + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def prepare_model_optim(): + model = DummyModel() + for p in model.parameters(): + p.grad = torch.ones_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + return model, optimizer + + +def test_overwrite(): + model = DummyModel() + with TemporaryDirectory() as dir_name: + with open(os.path.join(dir_name, MODEL_CKPT_FILE_NAME.replace('.bin', '-shard0.bin')), 'a') as f: + pass + with pytest.raises(RuntimeError, match=r'Save error: Checkpoint ".+" exists\. \(overwrite = False\)'): + save(dir_name, model) + + +def test_save_global(): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer) + assert len(os.listdir(dir_name)) == 5 + global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) + assert len(global_meta['meta']) == 1 and global_meta['meta'][0] == META_CKPT_FILE_NAME + meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME)) + assert len(meta['model']) == 1 + assert len(meta['optimizer']) == 1 + model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) + check_model_state_dict(model.state_dict(), model_state_dict) + optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) + check_optim_state_dict(optimizer.state_dict(), optimizer_state_dict) + other_state_dict = torch.load(os.path.join(dir_name, OTHER_CKPT_FILE_NAME)) + assert len(other_state_dict) == 0 + + +def test_save_global_shard(): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3) + assert len(os.listdir(dir_name)) == 7 + meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME)) + assert len(meta['model']) == 2 and len(meta['optimizer']) == 2 + model_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['model']] + assert len(set(model_state_dicts[0].keys()) & set(model_state_dicts[1].keys())) == 0 + check_model_state_dict(model.state_dict(), {**model_state_dicts[0], **model_state_dicts[1]}) + optimizer_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['optimizer']] + assert len(set(optimizer_state_dicts[0]['state'].keys()) & set(optimizer_state_dicts[1]['state'].keys())) == 0 + assert 'param_groups' in optimizer_state_dicts[0] and 'param_groups' not in optimizer_state_dicts[1] + check_optim_state_dict( + optimizer.state_dict(), { + 'state': { + **optimizer_state_dicts[0]['state'], + **optimizer_state_dicts[1]['state'] + }, + 'param_groups': optimizer_state_dicts[0]['param_groups'] + }) + + +def run_dist(rank, world_size, port, test_fn): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + test_fn() + + +def run_save_dist(dir_name): + model, optmizer = prepare_model_optim() + dist_metas = { + 'fc.weight': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1), + 'fc.bias': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1) + } + save(dir_name, model, optmizer, dist_meta=dist_metas) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_save_dist(): + with TemporaryDirectory() as dir_name: + fn = partial(run_save_dist, dir_name) + world_size = 2 + spawn(run_dist, world_size, test_fn=fn) + assert len(os.listdir(dir_name)) == 8 + global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) + assert len(global_meta['meta']) == 2 + for rank, meta_name in enumerate(global_meta['meta']): + meta = torch.load(os.path.join(dir_name, meta_name)) + assert meta.get('dist_meta', None) is not None + assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 + model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) + assert len(model_state_dict) == 2 + optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) + assert len(optimizer_state_dict['state']) == 2 + assert 'param_groups' in optimizer_state_dict + + +if __name__ == '__main__': + test_overwrite() + test_save_global() + test_save_global_shard() + test_save_dist() diff --git a/tests/test_utils/test_checkpoint_io/test_unmerge_param.py b/tests/test_utils/test_checkpoint_io/test_unmerge_param.py new file mode 100644 index 0000000000000000000000000000000000000000..8b83caa12359ff8c404e65f96701a6c390f9792d --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_unmerge_param.py @@ -0,0 +1,137 @@ +import torch +from colossalai.utils.checkpoint_io.meta import ParamRedistMeta +from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param + + +def test_flatten_zero_param_even() -> None: + redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12]) + orig_tensor = torch.rand(4, 4) + tensors = list(orig_tensor.reshape(-1).chunk(4)) + flat_tensors = flatten_zero_param(orig_tensor, redist_meta) + assert len(tensors) == len(flat_tensors) + for t, st in zip(tensors, flat_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(unmerged_tensors) == 1 + unmerged_tensors = unmerged_tensors[0] + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert torch.equal(t, tl) + + +def test_flatten_zero_param_uneven() -> None: + redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13]) + orig_tensor = torch.rand(4, 4) + tensors = list(orig_tensor.reshape(-1).split([13, 3])) + flat_tensors = flatten_zero_param(orig_tensor, redist_meta) + assert flat_tensors[0].size(0) == 0 and flat_tensors[-1].size(0) == 0 + flat_tensors = flat_tensors[1:-1] + assert len(tensors) == len(flat_tensors) + for t, st in zip(tensors, flat_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(unmerged_tensors) == 1 + unmerged_tensors = unmerged_tensors[0] + assert unmerged_tensors[0].size(0) == 0 and unmerged_tensors[-1].size(0) == 0 + unmerged_tensors = unmerged_tensors[1:-1] + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert torch.equal(t, tl) + + +def test_split_tp_param_1d_row() -> None: + redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4]) + orig_tensor = torch.rand(4, 4) + tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)] + split_tensors = split_tp_param(orig_tensor, redist_meta) + assert len(tensors) == len(split_tensors) + for t, st in zip(tensors, split_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert len(tl) == 1 + assert torch.equal(t, tl[0]) + + +def test_split_tp_param_1d_col() -> None: + redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4]) + orig_tensor = torch.rand(4, 4) + tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)] + split_tensors = split_tp_param(orig_tensor, redist_meta) + assert len(tensors) == len(split_tensors) + for t, st in zip(tensors, split_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert len(tl) == 1 + assert torch.equal(t, tl[0]) + + +def test_split_tp_param_2d() -> None: + redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) + orig_tensor = torch.rand(4, 6) + tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] + split_tensors = split_tp_param(orig_tensor, redist_meta) + assert len(tensors) == len(split_tensors) + for t, st in zip(tensors, split_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert len(tl) == 1 + assert torch.equal(t, tl[0]) + + +def test_split_tp_param_2d_reverse() -> None: + redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) + orig_tensor = torch.rand(4, 6) + tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] + split_tensors = split_tp_param(orig_tensor, redist_meta) + assert len(tensors) == len(split_tensors) + for t, st in zip(tensors, split_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert len(tl) == 1 + assert torch.equal(t, tl[0]) + + +def test_unmerge_param_hybrid() -> None: + redist_meta = ParamRedistMeta(2, + 6, + tp_shard_dims=[1, 0], + tp_num_parts=[3, 2], + zero_start_dp_rank=0, + zero_offsets=[0, 1]) + orig_tensor = torch.rand(4, 6) + tensors = [ + chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1) + for chunk in t.contiguous().reshape(-1).split([1, 3]) + ] + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(unmerged_tensors) == 6 and len(unmerged_tensors[0]) == 2 + for tp_rank in range(6): + for dp_rank in range(2): + assert torch.equal(tensors[tp_rank * 2 + dp_rank], unmerged_tensors[tp_rank][dp_rank]) + + +def test_unmerge_param_dummy() -> None: + redist_meta = ParamRedistMeta(1, 1) + orig_tensor = torch.rand(4, 6) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(unmerged_tensors) == 1 and len(unmerged_tensors[0]) == 1 + assert torch.equal(orig_tensor, unmerged_tensors[0][0]) + + +if __name__ == '__main__': + test_flatten_zero_param_even() + test_flatten_zero_param_uneven() + test_split_tp_param_1d_row() + test_split_tp_param_1d_col() + test_split_tp_param_2d() + test_split_tp_param_2d_reverse() + test_unmerge_param_hybrid() + test_unmerge_param_dummy() diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..89760a5456e774362204d09d049bcdd603beaae5 --- /dev/null +++ b/tests/test_utils/test_colo_checkpoint.py @@ -0,0 +1,206 @@ +import os +import shutil +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR + +import colossalai +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs + + +def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup): + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) + + +def init_1d_col_linear(weight, pg): + spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) + + +def init_1d_row_embedding(weight, pg): + spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) + + +def init_1d_col_embedding(weight, pg): + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + weight.set_process_group(pg) + weight.set_tensor_spec(*spec) + + +def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup): + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + for name, p in model.named_parameters(): + if not isinstance(p, ColoTensor): + continue + if 'embed' in name and 'weight' in name: + init_1d_col_embedding(p, pg) + if 'proj1' in name and ('weight' in name or 'bias' in name): + init_1d_col_linear(p, pg) + if 'proj2' in name and 'weight' in name: + init_1d_row_linear(p, pg) + if 'classifier' in name and ('weight' in name or 'bias' in name): + init_1d_col_linear(p, pg) + + +def check_param_equal(model, torch_model): + for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()): + assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape) + + +def remove(path): + """ param could either be relative or absolute. """ + if os.path.isfile(path) or os.path.islink(path): + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) + else: + raise ValueError("file {} is not a file or dir.".format(path)) + + +def compare_optims(optim1, optim2): + state1 = optim1.state_dict()['state'] + state2 = optim2.state_dict()['state'] + for k, p1 in state1.items(): + if k not in state2: + continue + p2 = state2[k] + for n, t1 in p1.items(): + if n not in p2: + continue + t2 = p2[n] + if isinstance(t1, ColoTensor): + assert isinstance(t2, ColoTensor) + assert torch.allclose(t1, t2, rtol=0, atol=0) + + +def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + # set_seed(1) + with ColoInitContext(device=get_current_device()): + model = model_builder(checkpoint=True) + + if use_mp_reload: + if 'bert' == model_name: + for name, p in model.named_parameters(): + if not isinstance(p, ColoTensor): + continue + # num_class = type_vocab_size = 2 | (8, 2) + if 'classifier' in name and 'weight' in name: + init_1d_row_linear(p, pg) + # num_class = vocab_size = 30524 | (30524, 8) + elif 'word_embeddings' in name and 'weight' in name: + init_1d_row_embedding(p, pg) + # num_class = seq_len = 512 | (512, 8) + elif 'position_embeddings' in name and 'weight' in name: + init_1d_row_embedding(p, pg) + # num_class = type_vocab_size = 2 | (2, 8) + elif 'token_type_embeddings' in name and 'weight' in name: + init_1d_col_embedding(p, pg) + elif p.process_group.tp_world_size() == 1: + p.set_process_group(pg) + elif "simple_net" == model_name: + init_spec_func(model, pg) + + model_reload = deepcopy(model) + model = model.cuda() + model.eval() + + model_reload = model_reload.cuda() + model_reload.eval() + + opt_class = torch.optim.Adam + colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1)) + colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1)) + + for i, (data, label) in enumerate(train_dataloader): + + # Zero grad + colo_optimizer.zero_grad() + colo_optimizer_reload.zero_grad() + + data = data.to(get_current_device()) + label = label.to(get_current_device()) + + dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group()) + dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group()) + + # Bcast rank0 data to all processes + if criterion: + output = model(data) + output_reload = model_reload(data) + loss = criterion(output, label) + loss_reload = criterion(output_reload, label) + else: + loss = model(data, label) + loss_reload = model_reload(data, label) + + loss.backward() + loss_reload.backward() + + colo_optimizer.step() + colo_optimizer_reload.step() + + if i > 2: + break + + if not os.path.isdir('./checkpoint') and rank == 0: + os.mkdir('./checkpoint') + dist.barrier() + + save_checkpoint('./checkpoint', 0, model, colo_optimizer, None) + load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None) + + check_param_equal(model, model_reload) + compare_optims(colo_optimizer, colo_optimizer_reload) + + if rank == 0: + remove('./checkpoint') + dist.barrier() + + +def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(tp_degree=world_size) + + # the data loader of BERT is in DDP mode, causing the input data is not replicated in the TP context + for model_name in ['bert']: + _run_checkpoint(model_name, + init_1d_row_for_linear_weight_spec, + use_ddp, + use_mp_reload, + test_scheduler=test_scheduler, + pg=pg) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@pytest.mark.parametrize('use_ddp', [False]) +@pytest.mark.parametrize('use_mp_reload', [True, False]) +# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda']) +@rerun_if_address_is_in_use() +def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None): + spawn(run_dist, world_size, use_ddp=use_ddp, use_mp_reload=use_mp_reload, test_scheduler=test_scheduler) + + +if __name__ == '__main__': + test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine") diff --git a/tests/test_utils/test_commons.py b/tests/test_utils/test_commons.py new file mode 100644 index 0000000000000000000000000000000000000000..2633d7da21aa3e71c3764b83515901eaff41c24a --- /dev/null +++ b/tests/test_utils/test_commons.py @@ -0,0 +1,41 @@ +import torch + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline +from colossalai.zero.legacy.sharded_param import ShardedTensor + + +def run_tensor_move(rank, world_size, port): + colossalai.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl') + + src_t = torch.ones(2, 3).cuda() + tgt_t = torch.zeros(2, 3) + + colo_model_data_tensor_move(src_t, tgt_t) + assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" + + src_t = torch.ones(2, 3) + tgt_t = torch.zeros(2, 3).cuda().half() + colo_model_data_tensor_move(src_t, tgt_t) + # the src_t has been removed + assert (src_t.numel() == 0) + assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" + + src_t = ShardedTensor(torch.ones(2, 3)) + tgt_t = ShardedTensor(torch.zeros(2, 3).cuda().half()) + colo_model_data_tensor_move(src_t, tgt_t) + assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" + + assert (tgt_t.device.type == 'cuda') + colo_model_data_tensor_move_inline(tgt_t, torch.device('cpu')) + assert (tgt_t.device.type == 'cpu') + + +@rerun_if_address_is_in_use() +def test_tensor_move(): + spawn(run_tensor_move, 1) + + +if __name__ == '__main__': + test_tensor_move() diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..7a28b0157384f2a664fe777d4743dd41462a3a16 --- /dev/null +++ b/tests/test_utils/test_flash_attention.py @@ -0,0 +1,113 @@ +import random + +import pytest +import torch +from einops import rearrange + +from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN +from colossalai.testing import clear_cache_before_run, parameterize + +if HAS_MEM_EFF_ATTN: + from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + + +def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale): + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + for z in range(Z): + for h in range(H): + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + return ref_out + + +@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@clear_cache_before_run() +@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): + D = H * D_HEAD + + c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") + attn = ColoAttention(D, H, dropout=0.1) + + x = torch.randn((B, S, D), dtype=dtype, device="cuda") + + qkv = c_attn(x) + q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H) + y = attn(q, k, v, attn_mask_type=AttnMaskType.causal) + + assert list(y.shape) == [B, S, D] + + dy = torch.rand_like(y) + y.backward(dy) + + +@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@clear_cache_before_run() +@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16): + D = H * D_HEAD + + c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") + attn = ColoAttention(D, H, dropout=0.1) + + x = torch.randn((B, S, D), dtype=dtype, device="cuda") + # attention mask of shape [B, S] with zero padding to max length S + mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)] + mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True) + + qkv = c_attn(x) + q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2) + y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding) + + assert list(y.shape) == [B, S, D] + + dy = torch.rand_like(y) + y.backward(dy) + + +@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@clear_cache_before_run() +@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16): + D = H * D_HEAD + + c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") + attn = ColoAttention(D, H, dropout=0.1) + + x = torch.randn((B, S, D), dtype=dtype, device="cuda") + qkv = c_attn(x) + q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2) + y = attn(q, k, v) + + assert list(y.shape) == [B, S, D] + + dy = torch.rand_like(y) + y.backward(dy) + + +@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@clear_cache_before_run() +@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)]) +def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16): + D = H * D_HEAD + + q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda") + kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda") + + attn = ColoAttention(D, H, dropout=0.1) + + src = torch.randn((B, S, D), dtype=dtype, device="cuda") + tgt = torch.randn((B, T, D), dtype=dtype, device="cuda") + + q = q_attn(tgt) + kv = kv_attn(src) + q = rearrange(q, 'b s (h d) -> b s h d', h=H) + k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2) + y = attn(q, k, v, attn_mask_type=AttnMaskType.causal) + + assert list(y.shape) == [B, T, D] + + dy = torch.rand_like(y) + y.backward(dy) diff --git a/tests/test_utils/test_lazy_init/test_distribute.py b/tests/test_utils/test_lazy_init/test_distribute.py new file mode 100644 index 0000000000000000000000000000000000000000..2c15ca84efaad42a112207e0e5127c5c513e6985 --- /dev/null +++ b/tests/test_utils/test_lazy_init/test_distribute.py @@ -0,0 +1,109 @@ +from typing import Optional + +import pytest +import torch +import torch.nn as nn + +import colossalai +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.d_tensor.layout import Layout +from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.common import print_rank_0 + +try: + from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor +except: + pass +from tests.kit.model_zoo import model_zoo + +# from utils import assert_dist_model_equal, set_seed + + +def find_shard_dim(shape: torch.Size) -> Optional[int]: + for dim, size in enumerate(shape): + if size % 2 == 0: + return dim + + +def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout: + shard_dim = find_shard_dim(original_tensor.shape) + dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {} + target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict) + layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=target_sharding_spec, + entire_shape=original_tensor.shape) + return layout + + +def _get_current_name(prefix: str, name: str) -> str: + return f'{prefix}.{name}'.lstrip('.') + + +def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict: + layout_dict = {} + + @torch.no_grad() + def generate_recursively(module: nn.Module, prefix: str = ''): + # recursively initialize the module + for name, mod in module.named_children(): + generate_recursively(mod, prefix=_get_current_name(prefix, name)) + + # initialize tensors directly attached to the current module + for name, param in module.named_parameters(recurse=False): + if isinstance(param, LazyTensor): + layout = make_layout(device_mesh, param) + layout_dict[_get_current_name(prefix, name)] = layout + + for name, buf in module.named_buffers(recurse=False): + if isinstance(buf, LazyTensor): + layout = make_layout(device_mesh, buf) + layout_dict[_get_current_name(prefix, name)] = layout + + generate_recursively(model) + + return layout_dict + + +@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) +def run_dist_lazy_init(subset, seed: int = 42): + sub_model_zoo = model_zoo.get_sub_registry(subset) + device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) + # FIXME(ver217): uncomment this line + # _MyTensor._pre_op_fn = lambda *args: set_seed(seed) + # LazyTensor._pre_op_fn = lambda *args: set_seed(seed) + + for name, entry in sub_model_zoo.items(): + # TODO(ver217): lazy init does not support weight norm, skip these models + if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'): + continue + print_rank_0(name) + model_fn, data_gen_fn, output_transform_fn, model_attr = entry + ctx = LazyInitContext(tensor_cls=_MyTensor) + with ctx: + model = model_fn() + ctx = LazyInitContext() + with ctx: + deferred_model = model_fn() + layout_dict = generate_layout_dict(deferred_model, device_mesh) + ctx.distribute(deferred_model, layout_dict, verbose=True) + # FIXME(ver217): uncomment this line + # assert_dist_model_equal(model, deferred_model, layout_dict) + + +def run_dist(rank, world_size, port) -> None: + colossalai.launch({}, rank=rank, world_size=world_size, host='localhost', port=port) + run_dist_lazy_init() + + +# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor +@pytest.mark.skip +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_dist_lazy_init(): + spawn(run_dist, 4) + + +if __name__ == '__main__': + test_dist_lazy_init() diff --git a/tests/test_utils/test_lazy_init/test_models.py b/tests/test_utils/test_lazy_init/test_models.py new file mode 100644 index 0000000000000000000000000000000000000000..9faddecbaca4dfa57b191b3ed7e351a847f235f3 --- /dev/null +++ b/tests/test_utils/test_lazy_init/test_models.py @@ -0,0 +1,23 @@ +import pytest + +from tests.kit.model_zoo import model_zoo + +# FIXME(ver217): uncomment this line +# from utils import check_lazy_init + + +# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor +@pytest.mark.skip +@pytest.mark.parametrize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) +def test_torchvision_models_lazy_init(subset): + sub_model_zoo = model_zoo.get_sub_registry(subset) + for name, entry in sub_model_zoo.items(): + # TODO(ver217): lazy init does not support weight norm, skip these models + if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'): + continue + # FIXME(ver217): uncomment this line + # check_lazy_init(entry, verbose=True) + + +if __name__ == '__main__': + test_torchvision_models_lazy_init('torchvision') diff --git a/tests/test_utils/test_lazy_init/utils.py b/tests/test_utils/test_lazy_init/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a8aeb4c8930c396bd2e3d9c72497f9d4e426a464 --- /dev/null +++ b/tests/test_utils/test_lazy_init/utils.py @@ -0,0 +1,85 @@ +import random +from typing import Any, Callable, Optional, Tuple + +import numpy as np +import torch + +from colossalai.tensor.d_tensor.layout_converter import to_global +from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor +from tests.kit.model_zoo.registry import ModelAttribute + +# model_fn, data_gen_fn, output_transform_fn, model_attr +TestingEntry = Tuple[Callable[[], torch.nn.Module], Callable[[], dict], Callable[[], dict], Optional[ModelAttribute]] + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def assert_model_eqaual(m1: torch.nn.Module, m2: torch.nn.Module) -> None: + s1 = m1.state_dict() + s2 = m2.state_dict() + + assert len(s1) == len(s2), f'len {len(s1)} vs {len(s2)}' + + for (n1, t1), (n2, t2) in zip(s1.items(), s2.items()): + assert n1 == n2 + assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}' + + +def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: Callable[[], dict], + output_transform_fn: Callable[[Any], dict]) -> None: + data = data_gen_fn() + + m1.eval() + m2.eval() + # run forward + with torch.no_grad(): + outputs1 = m1(**data) + outputs2 = m2(**data) + + # compare output + transformed_out1 = output_transform_fn(outputs1) + transformed_out2 = output_transform_fn(outputs2) + + assert len(transformed_out1) == len(transformed_out2) + + for key, out1 in transformed_out1.items(): + out2 = transformed_out2[key] + assert torch.allclose(out1, out2, atol=1e-5), \ + f'{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}' + + +def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None: + model_fn, data_gen_fn, output_transform_fn, model_attr = entry + _MyTensor._pre_op_fn = lambda *args: set_seed(seed) + LazyTensor._pre_op_fn = lambda *args: set_seed(seed) + ctx = LazyInitContext(tensor_cls=_MyTensor) + with ctx: + model = model_fn() + ctx = LazyInitContext() + with ctx: + deferred_model = model_fn() + deferred_model = ctx.materialize(deferred_model, verbose=verbose) + assert_model_eqaual(model, deferred_model) + if check_forward: + assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn) + if verbose: + print(f'{model.__class__.__name__} pass') + + +def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None: + state = model.state_dict() + distributed_state = distributed_model.state_dict() + + assert len(state) == len(distributed_state), f'len {len(state)} vs {len(distributed_state)}' + + for (n1, t1), (n2, t2) in zip(state.items(), distributed_state.items()): + assert n1 == n2 + t1 = t1.cuda() + t2 = t2.cuda() + if n2 in layout_dict: + t2 = to_global(t2, layout_dict[n2]) + assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}' diff --git a/tests/test_utils/test_lazy_init_ctx.py b/tests/test_utils/test_lazy_init_ctx.py new file mode 100644 index 0000000000000000000000000000000000000000..97efb3367490e772f351939bec7949fd86ad4da3 --- /dev/null +++ b/tests/test_utils/test_lazy_init_ctx.py @@ -0,0 +1,51 @@ +import torch +from colossalai.utils.model.lazy_init_context import LazyInitContext +from torchvision.models import resnet34 +import random +import numpy as np + +MANUAL_SEED = 0 +random.seed(MANUAL_SEED) +np.random.seed(MANUAL_SEED) +torch.manual_seed(MANUAL_SEED) + + +def test_lazy_init_with_meta(): + ctx = LazyInitContext(to_meta=True) + with ctx: + model = resnet34(num_classes=10) + + for param in model.parameters(): + assert param.is_meta + for buffer in model.buffers(): + assert buffer.is_meta + + ctx.lazy_init_parameters(model) + + for name, param in model.named_parameters(): + assert not param.is_meta, name + + for buffer in model.buffers(): + assert not buffer.is_meta + + +def test_lazy_init_without_meta(): + ctx = LazyInitContext(to_meta=False) + with ctx: + model = resnet34(num_classes=10) + + for param in model.parameters(): + assert not param.is_meta + for buffer in model.buffers(): + assert not buffer.is_meta + + conv1_weight_before_init = model.conv1.weight.clone() + ctx.lazy_init_parameters(model) + conv1_weight_after_init = model.conv1.weight.clone() + + assert not torch.allclose(conv1_weight_after_init, conv1_weight_before_init) + + +if __name__ == '__main__': + test_lazy_init_with_meta() + test_lazy_init_without_meta() diff --git a/tests/test_utils/test_memory.py b/tests/test_utils/test_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..c88c2f8ec3c5abbba033abb682fd788435212e72 --- /dev/null +++ b/tests/test_utils/test_memory.py @@ -0,0 +1,28 @@ +import pytest + +import colossalai +from colossalai.testing import spawn +from colossalai.utils.cuda import get_current_device +from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction + + +def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): + frac1 = colo_device_memory_capacity(get_current_device()) + colo_set_process_memory_fraction(0.5) + frac2 = colo_device_memory_capacity(get_current_device()) + assert frac2 * 2 == frac1 + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [3, 4]) +def test_memory_utils(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_memory_utils(world_size=2) diff --git a/tests/test_utils/test_norm_gradient_clipping.py b/tests/test_utils/test_norm_gradient_clipping.py new file mode 100644 index 0000000000000000000000000000000000000000..c0d678026c5fe184264ba03b9a7cdb0dbc5b1a06 --- /dev/null +++ b/tests/test_utils/test_norm_gradient_clipping.py @@ -0,0 +1,77 @@ +import pytest +import torch +from torch.nn.parameter import Parameter +from torch.nn.utils import clip_grad_norm_ + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor import ColoTensorSpec, ProcessGroup, distspec +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.utils.common import clip_grad_norm + + +def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8): + return abs(num - other) <= atol + rtol * other + + +def shard_param(p: ColoParameter) -> None: + pg = p.get_process_group() + p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()])) + p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach() + + +def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None: + pg = colo_p.get_process_group() + if p.shape != colo_p.shape: + grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()] + else: + grad = p.grad + assert torch.allclose(grad, colo_p.grad), f'diff: {torch.abs(grad - colo_p.grad)}' + + +@parameterize('dtype', [torch.float]) +@parameterize('device', ['mixed', 'cuda', 'cpu']) +@parameterize('norm_type', [2.0, 3.0, float('inf')]) +def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float): + print(f'{world_size}, {dtype}, {device}, {norm_type}') + cuda_device = get_current_device() + devices = [cuda_device] * 4 + if device == 'cpu': + devices = [torch.device('cpu')] * 4 + elif device == 'mixed': + devices = [cuda_device] * 2 + [torch.device('cpu')] * 2 + pg = ProcessGroup(tp_degree=world_size) + params = [Parameter(torch.empty(4, 4, dtype=dtype, device=devices[i])) for i in range(4)] + colo_params = [ + ColoParameter(torch.empty(4, 4, dtype=dtype, device=devices[i]), spec=ColoTensorSpec(pg)) for i in range(4) + ] + for p, colo_p in zip(params, colo_params): + grad = torch.rand_like(p) + p.grad = grad + colo_p.grad = grad.clone().detach() + shard_param(colo_params[0]) + shard_param(colo_params[2]) + torch_norm = clip_grad_norm_(params, 1.0, norm_type=norm_type) + colo_norm = clip_grad_norm(colo_params, 1.0, norm_type=norm_type) + assert close(torch_norm, colo_norm), f'diff: {abs(torch_norm-colo_norm)}' + for p, colo_p in zip(params, colo_params): + check_grad_equal(p, colo_p) + + +def run_dist(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_grad_clip_norm(world_size=world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_zero_clip_grad(world_size: int): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_zero_clip_grad(2) diff --git a/tests/test_utils/test_zero_gradient_clippling.py b/tests/test_utils/test_zero_gradient_clippling.py new file mode 100644 index 0000000000000000000000000000000000000000..e99cf388e929df3f34616ced50c4da851205c369 --- /dev/null +++ b/tests/test_utils/test_zero_gradient_clippling.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import checkpoint, clip_grad_norm_fp32 +from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy +from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 + + +def checkpoint_wrapper(module, enable=True): + if enable: + module.forward = partial(checkpoint, module.forward, False) + return module + + +class Net(nn.Module): + + def __init__(self, checkpoint=False) -> None: + super().__init__() + self.fc1 = nn.Linear(5, 5) + self.fc2 = nn.Linear(5, 5) + self.fc3 = nn.Linear(5, 1) + if checkpoint: + self.fc1 = checkpoint_wrapper(self.fc1) + self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3] + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0): + model.train() + optimizer.zero_grad() + with torch.cuda.amp.autocast(enabled=enable_autocast): + y = model(x) + loss = y.sum() + loss = loss.float() + loss.backward() + clip_grad(model, norm_type) + optimizer.step() + + +def clip_grad(model, norm_type): + if isinstance(model, DDP): + clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type) + else: + clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type) + + +def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: + if loose: + return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3) + return torch.allclose(tensor_a, tensor_b) + + +def check_grads(model, zero_model, loose=False): + rank = dist.get_rank() + for p, zero_p in zip(model.parameters(), zero_model.parameters()): + zero_grad = zero_p.grad.clone().to(p.device) + chunks = torch.flatten(p.grad).chunk(4) + if rank >= len(chunks): + continue + grad = chunks[rank] + if zero_p.zero_shard_padding > 0: + zero_grad = zero_grad[:-zero_p.zero_shard_padding] + assert grad.dtype == zero_grad.dtype + assert allclose(grad, zero_grad, loose=loose) + + +def check_params(model, zero_model, loose=False): + rank = dist.get_rank() + for p, zero_p in zip(model.parameters(), zero_model.parameters()): + zero_shard_padding = zero_p.zero_shard_padding + zero_p = zero_p.clone().to(p.device) + chunks = torch.flatten(p).chunk(4) + if rank >= len(chunks): + continue + p = chunks[rank] + if zero_shard_padding > 0: + zero_p = zero_p[:-zero_shard_padding] + assert p.dtype == zero_p.dtype + assert allclose(p, zero_p, loose=loose) + + +def run_dist(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_zero_clip_grad(): + world_size = 4 + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_zero_clip_grad() diff --git a/tests/test_zero/test_gemini/test_chunk_mgrv2.py b/tests/test_zero/test_gemini/test_chunk_mgrv2.py new file mode 100644 index 0000000000000000000000000000000000000000..7ea063877b5c483017c66c52ab5cafceff56d0fd --- /dev/null +++ b/tests/test_zero/test_gemini/test_chunk_mgrv2.py @@ -0,0 +1,67 @@ +import pytest +import torch + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.zero.gemini.chunk import ChunkManager +from tests.test_tensor.common_utils import debug_print + +CUDA_MEM_0 = {False: 512, True: 1024} +CUDA_MEM_1 = {False: 0, True: 1024} +CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}} + + +@parameterize('keep_gathered', [True, False]) +@parameterize('pin_memory', [True, False]) +def exam_chunk_memory(keep_gathered, pin_memory): + pg = ProcessGroup() + + debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory)) + + params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)] + config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)} + + chunk_manager = ChunkManager(config) + assert chunk_manager.total_mem['cpu'] == 0 + assert chunk_manager.total_mem['cuda'] == 0 + + for p in params: + chunk_manager.register_tensor(p, 'param', 2, pin_memory=pin_memory) + chunk_manager.close_all_groups() + assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] + + chunks = chunk_manager.get_chunks(params) + + for chunk in chunks: + chunk_manager.access_chunk(chunk) + assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[True] + + for chunk in chunks: + chunk_manager.release_chunk(chunk) + + assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] + + for chunk in chunks: + chunk_manager.move_chunk(chunk, torch.device('cpu')) + assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][True] + assert chunk_manager.total_mem['cuda'] == CUDA_MEM_1[keep_gathered] + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_chunk_memory() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_chunk_manager(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_chunk_manager(2) diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py new file mode 100644 index 0000000000000000000000000000000000000000..16764aa6b0b1b1234eb62e74e9c5f2fae4994a22 --- /dev/null +++ b/tests/test_zero/test_gemini/test_chunkv2.py @@ -0,0 +1,121 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.tensor import ColoParameter +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.zero.gemini import TensorState +from colossalai.zero.gemini.chunk import Chunk + + +def dist_sum(x): + temp = torch.tensor([x], device=get_current_device()) + dist.all_reduce(temp) + return temp.item() + + +def add_param(param_list, param_cp_list, *args, **kwargs): + param = ColoParameter(torch.randn(*args, **kwargs)) + param_list.append(param) + param_cp_list.append(param.clone()) + + +def check_euqal(param, param_cp): + if param.device != param_cp.device: + temp = param.data.to(param_cp.device) + else: + temp = param.data + return torch.equal(temp, param_cp.data) + + +@parameterize('init_device', [None, torch.device('cpu')]) +@parameterize('keep_gathered', [True, False]) +@parameterize('pin_memory', [True, False]) +def exam_chunk_basic(init_device, keep_gathered, pin_memory): + world_size = torch.distributed.get_world_size() + pg = ColoProcessGroup() + my_chunk = Chunk(chunk_size=1024, + process_group=pg, + dtype=torch.float32, + init_device=init_device, + cpu_shard_init=True, + keep_gathered=keep_gathered, + pin_memory=pin_memory) + + param_list = [] + param_cp_list = [] + + add_param(param_list, param_cp_list, 8, 8, 8, device='cuda') + add_param(param_list, param_cp_list, 4, 4) + add_param(param_list, param_cp_list, 4, 8, 2, device='cuda') + add_param(param_list, param_cp_list, 1, 1, 5) + + for param in param_list: + my_chunk.append_tensor(param) + assert my_chunk.utilized_size == 597 + for param, param_cp in zip(param_list, param_cp_list): + check_euqal(param, param_cp) + my_chunk.close_chunk() + + if keep_gathered is False: + assert my_chunk.cpu_shard.size(0) == 1024 // world_size + assert my_chunk.device_type == 'cpu' + assert my_chunk.can_move + my_chunk.shard_move(get_current_device()) + else: + assert my_chunk.cuda_global_chunk.size(0) == 1024 + assert my_chunk.device_type == 'cuda' + assert not my_chunk.can_move + + assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size + flag = my_chunk.has_inf_or_nan + assert not flag, "has_inf_or_nan is {}".format(flag) + + my_chunk.access_chunk() + assert my_chunk.device_type == 'cuda' + for param, param_cp in zip(param_list, param_cp_list): + check_euqal(param, param_cp) + + assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4 + my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE) + assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 3 + assert my_chunk.tensor_state_cnter[TensorState.COMPUTE] == 1 + assert not my_chunk.can_release + + for param in param_list: + my_chunk.tensor_trans_state(param, TensorState.COMPUTE) + my_chunk.tensor_trans_state(param, TensorState.HOLD_AFTER_BWD) + my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE) + + assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4 + assert my_chunk.can_reduce + my_chunk.reduce() + assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4 + + if keep_gathered is False: + assert my_chunk.cuda_shard.size(0) == 1024 // world_size + assert my_chunk.device_type == 'cuda' + assert my_chunk.can_move + else: + assert my_chunk.cuda_global_chunk.size(0) == 1024 + assert my_chunk.device_type == 'cuda' + assert not my_chunk.can_move + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_chunk_basic() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2, 4]) +@rerun_if_address_is_in_use() +def test_chunk_function(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_chunk_function(4) diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..f2cbb7fb77d600a8b7918897764afaa46a1f674d --- /dev/null +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -0,0 +1,161 @@ +import pytest +import torch +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor import ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager +from tests.components_to_test import run_fwd, run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed + + +def check_grad(model: ZeroDDP, torch_model: torch.nn.Module): + chunk_manager = model.chunk_manager + param_list = [p for p in model.parameters()] + chunk_list = chunk_manager.get_chunks(param_list) + for chunk in chunk_list: + chunk_manager.access_chunk(chunk) + + for (p0, p1) in zip(model.parameters(), torch_model.parameters()): + assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5) + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +@parameterize('keep_gather', [False, True]) +@parameterize('model_name', ['gpt2', 'bert', 'albert']) +@parameterize('use_grad_checkpoint', [False, True]) +def exam_gpt_fwd_bwd( + placement_policy, + keep_gather, + model_name: str, + use_grad_checkpoint: bool = False, +): + init_device = get_current_device() + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + set_seed(42) + with ColoInitContext(device=init_device): + model = model_builder(use_grad_checkpoint) + + set_seed(42) + torch_model = model_builder(use_grad_checkpoint).cuda() + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gather + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1) + + pg = ProcessGroup() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) + + set_seed(pg.dp_local_rank()) + for i, (input_ids, label) in enumerate(train_dataloader): + # you can only test a single fwd + bwd. + # after bwd param is grad for Gemini, due to the chunk reuse optimization. + if i > 0: + break + input_ids, label = input_ids.cuda(), label.cuda() + + torch_optim.zero_grad() + zero_optim.zero_grad() + + # set random seed is same as torch_model.eval() + set_seed(42) + torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) + set_seed(42) + loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + + assert torch.equal(torch_loss, loss) + + check_grad(model, torch_model) + + +@parameterize('placement_policy', ['cuda', 'cpu']) +@parameterize('keep_gather', [False, True]) +@parameterize('model_name', ['gpt2', 'bert', 'albert']) +@parameterize('scatter_after_inference', [False, True]) +def exam_gpt_inference( + placement_policy, + keep_gather, + model_name: str, + scatter_after_inference: bool = False, +): + init_device = get_current_device() + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + set_seed(42) + with ColoInitContext(device=init_device): + model = model_builder() + + set_seed(42) + torch_model = model_builder().cuda() + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gather + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True, scatter_after_inference=scatter_after_inference) + + pg = ProcessGroup() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) + + set_seed(pg.dp_local_rank()) + model.eval() + torch_model.eval() + for i, (input_ids, label) in enumerate(train_dataloader): + # you can only test a single fwd + bwd. + # after bwd param is grad for Gemini, due to the chunk reuse optimization. + if i > 0: + break + with torch.no_grad(): + input_ids, label = input_ids.cuda(), label.cuda() + + torch_loss = run_fwd(torch_model, input_ids, label, criterion) + loss = run_fwd(model, input_ids, label, criterion) + + assert torch.equal(torch_loss, loss) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_gpt_fwd_bwd() + exam_gpt_inference() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_gpt(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_gpt(4) diff --git a/tests/test_zero/test_gemini/test_gemini_use_rmt.py b/tests/test_zero/test_gemini/test_gemini_use_rmt.py new file mode 100644 index 0000000000000000000000000000000000000000..dd580976d8eafd36c83601fb397242cf52f60d09 --- /dev/null +++ b/tests/test_zero/test_gemini/test_gemini_use_rmt.py @@ -0,0 +1,101 @@ +import pytest +import torch + +import colossalai +from colossalai.tensor import ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.zero import ColoInitContext, ZeroDDP +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed + +# run gemini use the runtime memory tracer + + +@parameterize('placement_policy', ['auto']) +@parameterize('keep_gather', [False]) +@parameterize('model_name', ['repeated_computed_layers', 'bert', 'albert', 'gpt2']) +@parameterize('use_grad_checkpoint', [False, True]) +def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False): + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device='cpu'): + model = model_builder(use_grad_checkpoint) + + print(f'model_name {model_name}') + runtime_mem_tracer = RuntimeMemTracer(model) + for i, (input_ids, label) in enumerate(train_dataloader): + if i > 0: + break + input_ids, label = input_ids.cuda(), label.cuda() + + # mem tracing + if i == 0: + run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer) + memstats = runtime_mem_tracer.memstats() + runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list + print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data)) + print('runtime tracer: ', runtime_tracer_non_model_data) + print([memstats.param_used_step(p) for p in model.parameters()]) + + if model_name == 'repeated_computed_layers': + for idx, p in enumerate(model.parameters()): + step_list = memstats.param_used_step(p) + if idx < 4: + assert len(step_list) == 4 + + if model_name == 'repeated_computed_layers': + for idx, p in enumerate(model.parameters()): + step_list = memstats.param_used_step(p) + if idx < 4: + assert len(step_list) == 4 + + world_size = torch.distributed.get_world_size() + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gather + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + pg = ProcessGroup() + set_seed(pg.dp_local_rank()) + for i, (input_ids, label) in enumerate(train_dataloader): + # you can only test a single fwd + bwd. + # after bwd param is grad for Gemini, due to the chunk reuse optimization. + # print(f'iteration {i}') + if i > 4: + break + input_ids, label = input_ids.cuda(), label.cuda() + + set_seed(42) + loss = run_fwd_bwd(model, input_ids, label, criterion, model) + + gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') + + # print('gemini non model data:', gemini_non_model_data) + + assert len(gemini_non_model_data) == len(runtime_tracer_non_model_data), \ + f'model_name {model_name} {len(gemini_non_model_data)} vs {len(runtime_tracer_non_model_data)}' + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_gemini_use_rmt() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_gemini_use_rmt(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_gemini_use_rmt(1) diff --git a/tests/test_zero/test_gemini/test_get_torch_model.py b/tests/test_zero/test_gemini/test_get_torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b3e3b2b22fc3709b3cf928384b46f7b2928ef4ff --- /dev/null +++ b/tests/test_zero/test_gemini/test_get_torch_model.py @@ -0,0 +1,52 @@ +import pytest +import torch + +import colossalai +from colossalai.tensor import ColoParameter +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext, GeminiDDP +from colossalai.zero.gemini.utils import get_static_torch_model +from tests.components_to_test.registry import non_distributed_component_funcs + + +@parameterize('model_name', ['hanging_param_model', 'resnet18', 'gpt2']) +def run_convert_torch_module(model_name: str): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, _, _, _, _ = get_components_func() + + with ColoInitContext(device=torch.device("cpu")): + model = model_builder(checkpoint=False) + model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True) + pytorch_model = get_static_torch_model(model, only_rank_0=False) + + for n, p in pytorch_model.named_parameters(): + assert type(p) == torch.nn.Parameter, f"type error: {n} is a {type(p)}" + + # get the static model should not change the original model + for n, p in model.named_parameters(): + assert isinstance(p, ColoParameter) + + for (pn, pm), (cn, cm) in zip(pytorch_model.named_modules(), model.named_modules()): + assert pn == cn + assert id(pm) != id(cm) + for pp, cp in zip(pm.parameters(recurse=False), cm.parameters(recurse=False)): + assert id(pp) != id(cp) + assert pp.shape == cp.shape + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_convert_torch_module() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_convert_torch_module(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_convert_torch_module(2) diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..38b6e474ea986a74f451b30ddb4b45c94f6b12dc --- /dev/null +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -0,0 +1,107 @@ +import pytest +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed + + +def check_param(model: ZeroDDP, torch_model: torch.nn.Module): + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + # key is 'module.model.PARAMETER', so we truncate it + key = key[7:] + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) + assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +@parameterize('model_name', ['gpt2']) +def exam_grad_clipping(placement_policy, model_name: str): + set_seed(1912) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + torch_model = model_builder().cuda() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + + init_dev = get_current_device() + with ColoInitContext(device=init_dev): + model = model_builder() + + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + p.data.copy_(torch_p.data) + + world_size = torch.distributed.get_world_size() + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = False + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0) + + model.train() + torch_model.train() + + set_seed(dist.get_rank() * 3 + 128) + for i, (data, label) in enumerate(train_dataloader): + if i > 2: + break + data = data.cuda() + label = label.cuda() + + zero_optim.zero_grad() + torch_optim.zero_grad() + + torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim) + loss = run_fwd_bwd(model, data, label, criterion, zero_optim) + assert_close(torch_loss, loss) + + import apex.amp as apex_amp + torch.nn.utils.clip_grad_norm_(apex_amp.master_params(torch_optim), 1.0) + torch_optim.step() + zero_optim.step() + + check_param(model, torch_model) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_grad_clipping() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_grad_clip(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_grad_clip(2) diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..790a0611c9ddab198b18bf6c5276521efabb399e --- /dev/null +++ b/tests/test_zero/test_gemini/test_inference.py @@ -0,0 +1,132 @@ +from typing import Callable + +import pytest +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx, zero_model_wrapper +from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import debug_print, set_seed + + +def check_param(model: ZeroDDP, torch_model: torch.nn.Module): + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + # key is 'module.model.PARAMETER', so we truncate it + key = key[7:] + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) + assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) + + +def multi_chunk_init(model: torch.nn.Module, placement_policy: str): + world_size = dist.get_world_size() + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = False + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + return model + + +def single_chunk_init(model: torch.nn.Module, placement_policy: str): + gemini_config = dict( + device=get_current_device(), + placement_policy=placement_policy, + pin_memory=True, + ) + model = zero_model_wrapper(model=model, zero_stage=3, gemini_config=gemini_config) + return model + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +@parameterize('model_name', ['gpt2']) +@parameterize('model_init_func', [single_chunk_init, multi_chunk_init]) +def exam_inference(placement_policy: str, model_name: str, model_init_func: Callable): + set_seed(19360226) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + torch_model = model_builder().cuda() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + + init_dev = get_current_device() + with ColoInitContext(device=init_dev): + model = model_builder() + + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + p.data.copy_(torch_p.data) + + model = model_init_func(model, placement_policy) + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128) + + model.eval() + torch_model.eval() + + set_seed(dist.get_rank() * 3 + 128) + train_dataloader = iter(train_dataloader) + + def train_iter(): + input_ids, label = next(train_dataloader) + input_ids, label = input_ids.cuda(), label.cuda() + zero_optim.zero_grad() + torch_optim.zero_grad() + torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) + loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + assert_close(torch_loss, loss) + zero_optim.step() + torch_optim.step() + check_param(model, torch_model) + + def inference_iter(): + input_ids, label = next(train_dataloader) + input_ids, label = input_ids.cuda(), label.cuda() + with torch.no_grad(): + torch_output = torch_model(input_ids) + torch_loss = criterion(torch_output.float(), label) + zero_output = model(input_ids) + zero_loss = criterion(zero_output.float(), label) + assert_close(torch_loss, zero_loss) + + train_iter() + inference_iter() + train_iter() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_inference() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_inference(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_inference(1) diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce20c16e8f965824c7a3c63e4be6a5ed8f851a7 --- /dev/null +++ b/tests/test_zero/test_gemini/test_optim.py @@ -0,0 +1,159 @@ +import pytest +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx +from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import debug_print, set_seed + +# this model is large enough to slice to chunks +TEST_MODELS = ['gpt2'] +# these models are too small, all parameters in these models are compacted into one chunk +EXAMPLE_MODELS = ['albert', 'beit', 'bert', 'hanging_param_model', 'nested_model', 'repeated_computed_layers'] + + +def check_param(model: ZeroDDP, torch_model: torch.nn.Module): + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + # key is 'module.model.PARAMETER', so we truncate it + key = key[7:] + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) + assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +@parameterize('model_name', TEST_MODELS) +def exam_model_step(placement_policy, model_name: str): + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + torch_model = model_builder().cuda() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + + init_dev = get_current_device() + with ColoInitContext(device=init_dev): + model = model_builder() + + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + p.data.copy_(torch_p.data) + + world_size = torch.distributed.get_world_size() + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = False + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128) + + model.eval() + torch_model.eval() + + set_seed(dist.get_rank() * 3 + 128) + for i, (input_ids, label) in enumerate(train_dataloader): + if i > 2: + break + input_ids, label = input_ids.cuda(), label.cuda() + zero_optim.zero_grad() + torch_optim.zero_grad() + + torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) + loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + assert_close(torch_loss, loss) + + zero_optim.step() + torch_optim.step() + + check_param(model, torch_model) + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +@parameterize('model_name', EXAMPLE_MODELS) +def exam_tiny_example(placement_policy, model_name: str): + set_seed(2008) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + torch_model = model_builder().cuda() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=2) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + + init_dev = get_current_device() + with ColoInitContext(device=init_dev): + model = model_builder() + + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + p.data.copy_(torch_p.data) + + chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) + + model.eval() + torch_model.eval() + + set_seed(dist.get_rank() * 3 + 128) + for i, (input_ids, label) in enumerate(train_dataloader): + if i > 2: + break + + input_ids = input_ids.cuda() + label = label.cuda() + + zero_optim.zero_grad() + torch_optim.zero_grad() + + torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) + loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + assert_close(torch_loss, loss, rtol=1.5e-6, atol=2e-5) # atol should be 2e-5 for torch lower than 1.12 + + zero_optim.step() + torch_optim.step() + + check_param(model, torch_model) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_model_step() + exam_tiny_example() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_optim(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_optim(1) diff --git a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..0e6f283aa5d23ed611d75f0209e3ef4298adb0fe --- /dev/null +++ b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py @@ -0,0 +1,54 @@ +from copy import deepcopy + +import numpy as np +import torch + +from colossalai.testing import clear_cache_before_run +from colossalai.zero import ColoInitContext +from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs + + +@clear_cache_before_run() +def test_runtime_mem_tracer(): + test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model', 'albert'] + + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, _, criterion = get_components_func() + + with ColoInitContext(device='cpu'): + model = model_builder(checkpoint=False) + + model_bk = deepcopy(model) + runtime_mem_tracer = RuntimeMemTracer(model) + + for i, (data, label) in enumerate(train_dataloader): + if i > 1: + break + data = data.cuda() + label = label.cuda() + + run_fwd_bwd(runtime_mem_tracer, data, label, criterion, optimizer=runtime_mem_tracer) + + for p1, p2 in zip(model_bk.parameters(), model.parameters()): + torch.allclose(p1.to(torch.half), p2) + + non_model_data_list = runtime_mem_tracer._memstats.non_model_data_list('cuda') + cuda_non_model_data_list = np.array(non_model_data_list) / 1024**2 + print("cuda_non_model_data_list", len(cuda_non_model_data_list)) + print(non_model_data_list) + + cnt1 = 0 + for p in runtime_mem_tracer.parameters_in_runtime_order(): + cnt1 += 1 + cnt2 = 0 + for p in model.parameters(): + cnt2 += 1 + assert cnt2 == cnt1, f'visited param number {cnt1} vs real param number {cnt2}' + del model + + +if __name__ == '__main__': + test_runtime_mem_tracer() diff --git a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py new file mode 100644 index 0000000000000000000000000000000000000000..35b3b93ade0c223100c660c7612a97e8591bd6ff --- /dev/null +++ b/tests/test_zero/test_gemini/test_search.py @@ -0,0 +1,118 @@ +import pytest +import torch + +import colossalai +from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext +from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration +from tests.components_to_test.registry import non_distributed_component_funcs + + +def init_1d_row_spec(model, pg: ProcessGroup): + tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + for n, p in model.named_parameters(): + if 'weight' in n and 'ln' not in n: + p.set_process_group(pg) + p.set_tensor_spec(*tensor_spec) + + +def exam_search_chunk_size(): + world_size = torch.distributed.get_world_size() + pg_tp = ProcessGroup(tp_degree=world_size) + + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + # make sure torch_model and model has the same parameter values + with ColoInitContext(device=get_current_device()): + model = model_builder() + init_1d_row_spec(model, pg_tp) + config_dict, *_ = search_chunk_configuration(model, + search_range_mb=1, + search_interval_byte=16, + min_chunk_size_mb=0, + filter_exlarge_params=True) + + for key in config_dict: + chunk_size = config_dict[key]['chunk_size'] + if world_size == 1: + assert chunk_size == 31616 + else: + assert chunk_size == 1024 + + +def exam_search_strict_ddp(): + world_size = torch.distributed.get_world_size() + default_shard_pg = ProcessGroup(tp_degree=world_size) + default_shard_spec = ShardSpec([-1], [world_size]) + + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + # get the chunk configuration over replicated models + with ColoInitContext(device=get_current_device()): + ddp_model = model_builder() + re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model, + search_range_mb=1, + search_interval_byte=16, + min_chunk_size_mb=0, + filter_exlarge_params=True, + strict_ddp_flag=False) + # get the chunk configuration over sharded ddp models + with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg, + default_dist_spec=default_shard_spec): + sharded_ddp_model = model_builder() + sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model, + search_range_mb=1, + search_interval_byte=16, + min_chunk_size_mb=0, + filter_exlarge_params=True, + strict_ddp_flag=True) + assert re_dict == sh_dict + for key in re_dict: + assert re_dict[key] == sh_dict[key] + + assert re_total == sh_total + assert re_wasted == sh_wasted + + +def exam_chunk_manager(): + world_size = torch.distributed.get_world_size() + default_shard_pg = ProcessGroup(tp_degree=world_size) + default_shard_spec = ShardSpec([-1], [world_size]) + + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg, + default_dist_spec=default_shard_spec): + sharded_ddp_model = model_builder() + chunk_manager = init_chunk_manager(sharded_ddp_model, + get_current_device(), + hidden_dim=16, + search_range_mb=1, + min_chunk_size_mb=0, + filter_exlarge_params=True, + strict_ddp_flag=True) + config_dict = chunk_manager.dp_degree_chunk_size_dict + assert len(config_dict) == 1 + assert config_dict[world_size] == 31616 + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_search_chunk_size() + exam_search_strict_ddp() + exam_chunk_manager() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_search(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_search(4) diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..66e05f3ed1ecb7b6a0d58ff3c29dad6c6a477ef8 --- /dev/null +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -0,0 +1,107 @@ +import pytest +import torch +from torch.testing import assert_close + +import colossalai +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext, ZeroDDP +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import debug_print, set_seed + + +def ignore_the_first_parameter(model: torch.nn.Module): + for name, param in model.named_parameters(): + print(f"parameter `{name}` is set ignored") + ZeroDDP.set_params_to_ignore([param]) + return + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) +@parameterize('keep_gathered', [True, False]) +@parameterize('model_name', ['gpt2', 'bert']) +def exam_state_dict(placement_policy, keep_gathered, model_name: str): + set_seed(431) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + torch_model = model_builder() + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gathered + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + model.train() + + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) +@parameterize('keep_gathered', [True, False]) +@parameterize('model_name', ['gpt2', 'bert']) +def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): + set_seed(431) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + set_seed(451) + torch_model = model_builder() # get a different model + + world_size = torch.distributed.get_world_size() + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gathered + + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + torch_dict = torch_model.state_dict() + model.load_state_dict(torch_dict, strict=False) + zero_dict = model.state_dict(only_rank_0=False) + + for key, value in torch_dict.items(): + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_state_dict() + exam_load_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_zero_ddp(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_zero_ddp(1) diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py new file mode 100644 index 0000000000000000000000000000000000000000..96c26a1de4df58111c943cd3e11d2a5ce673ab85 --- /dev/null +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py @@ -0,0 +1,56 @@ +import pytest +import torch +from torch.testing import assert_close + +import colossalai +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext, ZeroDDP +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager +from tests.components_to_test.registry import non_distributed_component_funcs + + +@parameterize('placement_policy', ['cuda', 'cpu']) +@parameterize('model_name', ['gpt2', 'bert']) +def exam_state_dict(placement_policy, model_name: str): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 + + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager) + model.train() + + zero_dict = model.state_dict(only_rank_0=False) + accumulated_keys = set() + # ensure number of shards > 1 + for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): + for key, value in shard.items(): + assert key not in accumulated_keys, f"key `{key}` is duplicated." + accumulated_keys.add(key) + assert key in zero_dict, f"{key} not in ZeRO dictionary." + assert torch.equal(value, zero_dict[key]), f"{key} not equal." + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_zero_ddp_state_dict_shard(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_zero_ddp_state_dict_shard(1) diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..a8af176c5b3dc9b1880fd6ce5afc8584114b3b50 --- /dev/null +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -0,0 +1,86 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import debug_print, set_seed + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) +@parameterize('keep_gathered', [True, False]) +def exam_zero_optim_state_dict(placement_policy, keep_gathered): + set_seed(431) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + set_seed(451) + torch_model = model_builder() # get a different model + + world_size = torch.distributed.get_world_size() + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gathered + + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + optimizer = HybridAdam(model.parameters()) + optim = ZeroOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 + + set_seed(dist.get_rank() * 3 + 128) + model.train() + for i, (input_ids, label) in enumerate(train_dataloader): + if i > 0: + break + optim.zero_grad() + logits = model(input_ids) + logits = logits.float() + loss = criterion(logits, input_ids) + optim.backward(loss) + optim.step() + + optim_state_dict = optim.state_dict() + optim.load_state_dict(optim_state_dict) + new_state = optim.state_dict()['state'] + org_state = optim_state_dict['state'] + + for k, v in org_state.items(): + w = new_state[k] + for n, m in v.items(): + if isinstance(m, torch.Tensor): + o = w[n] + assert torch.equal(m, o) + else: + assert m == w[n] + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_zero_optim_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_zero_optim(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_zero_optim(1) diff --git a/tests/test_zero/test_legacy/common.py b/tests/test_zero/test_legacy/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2c3d122c79af9c959d91c4c4dcf0e5f984656ed1 --- /dev/null +++ b/tests/test_zero/test_legacy/common.py @@ -0,0 +1,140 @@ +from functools import partial + +import torch +import torch.distributed as dist + +from colossalai.logging import get_dist_logger +from colossalai.utils import checkpoint +from colossalai.zero.legacy.shard_utils import TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 + +LOGGER = get_dist_logger('zero_test') + +MP_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), parallel=dict(pipeline=dict(size=1), tensor=dict(size=2, mode=None))) + +_ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25, + fp32_reduce_scatter=False, + tensor_placement_policy='cuda', + gradient_predivide_factor=1.0, + shard_strategy=TensorShardStrategy(), + reuse_fp16_shard=False) + +_ZERO_OPTIMIZER_CONFIG = dict(initial_scale=2**5, + min_scale=1, + growth_factor=2, + backoff_factor=0.5, + growth_interval=1000, + hysteresis=2, + max_scale=2**32) + +ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), + zero=dict( + model_config=_ZERO_MODEL_CONFIG, + optimizer_config=_ZERO_OPTIMIZER_CONFIG, + ), + parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None))) + +CONFIG = dict(fp16=dict(mode=None,), + zero=dict(level=3, + verbose=False, + offload_optimizer_config=dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False), + offload_param_config=dict(device='cpu', + pin_memory=True, + buffer_count=5, + buffer_size=1e8, + max_in_cpu=1e9)), + parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None))) + + +def run_fwd_bwd(model, data, label, criterion, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + if isinstance(model, ShardedModelV2): + model.backward(loss) + else: + loss.backward() + + +def checkpoint_wrapper(module, enable=True): + if enable: + module.forward = partial(checkpoint, module.forward) + return module + + +def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: + if loose: + return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3) + return torch.allclose(tensor_a, tensor_b) + + +def check_grads(model, zero_model, loose=False): + for p, zero_p in zip(model.parameters(), zero_model.parameters()): + zero_grad = zero_p.grad.clone().to(p.device) + grad = p.grad.float() + assert grad.dtype == zero_grad.dtype + assert allclose(grad, zero_grad, loose=loose) + + +def check_params(model, zero_model, loose=False): + for p, zero_p in zip(model.parameters(), zero_model.parameters()): + zero_p = zero_p.clone().to(p.device) + # assert p.dtype == zero_p.dtype + assert allclose(p.float(), zero_p.float(), loose=loose), f"diff {p.float() - zero_p.float()}" + + +def check_grads_padding(model, zero_model, loose=False): + rank = dist.get_rank() + for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): + # zero_grad = zero_p.grad.clone().to(p.device) + if zero_p.colo_attr.is_replicated: + zero_grad = zero_p.colo_attr.grad_payload.clone().to(p.device) + chunks = torch.flatten(p.grad).chunk(dist.get_world_size()) + if rank >= len(chunks): + continue + grad = chunks[rank].float() + if zero_grad.size(0) > grad.size(0): + zero_grad = zero_grad[:grad.size(0)] + else: + zero_grad = zero_p.colo_attr.grad_payload + grad = p.grad.to(zero_grad.dtype) + + assert grad.dtype == zero_grad.dtype + assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}' + + +def check_params_padding(model, zero_model, loose=False): + rank = dist.get_rank() + for p, zero_p in zip(model.parameters(), zero_model.parameters()): + zero_p = zero_p.clone().to(p.device) + chunks = torch.flatten(p).chunk(dist.get_world_size()) + if rank >= len(chunks): + continue + p = chunks[rank] + if zero_p.size(0) > p.size(0): + zero_p = zero_p[:p.size(0)] + assert p.dtype == zero_p.dtype + assert allclose(p, zero_p, loose=loose) + + +def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False): + rank = dist.get_rank() + for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): + if zero_p.colo_attr.param_is_sharded: + zero_p = zero_p.colo_attr.data_payload.to(p.device).float() + chunks = torch.flatten(p).chunk(dist.get_world_size()) + if rank >= len(chunks): + continue + p = chunks[rank].float() + if zero_p.size(0) > p.size(0): + zero_p = zero_p[:p.size(0)] + else: + zero_p = zero_p.colo_attr.data_payload.to(p.device) + + assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype) + assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}' diff --git a/tests/test_zero/test_legacy/test_found_inf.py b/tests/test_zero/test_legacy/test_found_inf.py new file mode 100644 index 0000000000000000000000000000000000000000..e90158e0a43b65ae756a44d45e5203d82a5cc292 --- /dev/null +++ b/tests/test_zero/test_legacy/test_found_inf.py @@ -0,0 +1,67 @@ +import pytest +import torch +from common import CONFIG +from test_sharded_optim_v2 import _run_step + +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 +from colossalai.zero.low_level._utils import has_inf_or_nan +from tests.components_to_test.registry import non_distributed_component_funcs + + +@parameterize("cpu_offload", [True, False]) +@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) +@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) +def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio): + test_models = ['repeated_computed_layers'] + shard_strategy = shard_strategy_class() + + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() + + with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = model_builder(checkpoint=True) + zero_model = ShardedModelV2( + zero_model, + shard_strategy, + tensor_placement_policy='cpu' if cpu_offload else 'cuda', + reuse_fp16_shard=True, + ) + + sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3) + sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio) + + for i, (data, label) in enumerate(train_dataloader): + if i > 1: + break + assert zero_model.overflow_counter == 0 + data, label = data.cuda(), label.cuda() + _run_step(zero_model, sharded_optim, data, label, criterion, False) + for param in zero_model.parameters(): + assert not has_inf_or_nan(param.colo_attr.data_payload) + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_test_found_inf() + + +# use_cpuadam = True can be used with cpu_offload = False +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_found_inf(world_size): + spawn(_run_dist, world_size) + + +if __name__ == '__main__': + test_found_inf(world_size=2) diff --git a/tests/test_zero/test_legacy/test_gemini_manager.py b/tests/test_zero/test_legacy/test_gemini_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..0e956f7cc6178790492abfb91bcdcb2eb5946823 --- /dev/null +++ b/tests/test_zero/test_legacy/test_gemini_manager.py @@ -0,0 +1,75 @@ +import pytest +import torch + +from colossalai.testing import clear_cache_before_run +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState + + +@pytest.mark.dist +@clear_cache_before_run() +def test_gemini_manager(): + # reset the manager, in case that there exists memory information left + manager = StatefulTensor.GST_MGR + manager.reset() + + # occupation 8 + st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda')) + # occupation 60 + st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu')) + + # occupation 28 + t1 = torch.empty(7, device='cuda') + # occupation 12 + t2 = torch.empty(3, device='cpu') + st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD) + st4 = StatefulTensor(None, TensorState.FREE) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 60 + assert manager.total_mem['cuda'] == 36 + assert manager.state_mem['cpu'][TensorState.HOLD] == 60 + assert manager.state_mem['cuda'][TensorState.HOLD] == 8 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28 + + st4.payload_reset(t2) + st3.payload_reset(t2) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 84 + assert manager.total_mem['cuda'] == 8 + assert manager.state_mem['cpu'][TensorState.HOLD] == 72 + assert manager.state_mem['cuda'][TensorState.HOLD] == 8 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0 + + st1.move_to(torch.device('cpu')) + st2.move_to(torch.device('cpu')) + st3.move_to(torch.device('cuda', 0)) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 80 + assert manager.total_mem['cuda'] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD] == 80 + assert manager.state_mem['cuda'][TensorState.HOLD] == 0 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 + + st1.trans_state(TensorState.COMPUTE) + st2.trans_state(TensorState.COMPUTE) + st2.trans_state(TensorState.HOLD_AFTER_BWD) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 80 + assert manager.total_mem['cuda'] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD] == 12 + assert manager.state_mem['cuda'][TensorState.HOLD] == 0 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0 + assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8 + assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0 + + +if __name__ == '__main__': + test_gemini_manager() diff --git a/tests/test_zero/test_legacy/test_init_context.py b/tests/test_zero/test_legacy/test_init_context.py new file mode 100644 index 0000000000000000000000000000000000000000..84493827193eef6b47db051fe1fe55c1a10098b1 --- /dev/null +++ b/tests/test_zero/test_legacy/test_init_context.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pytest +import torch +from common import CONFIG + +import colossalai +from colossalai.logging import get_dist_logger +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.utils.memory import colo_device_memory_used +from colossalai.zero.gemini.memory_tracer.utils import colo_model_mem_usage +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from tests.components_to_test.registry import non_distributed_component_funcs + + +@parameterize("init_device_type", ['cpu', 'cuda']) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_model_test(init_device_type, shard_strategy_class): + logger = get_dist_logger("test_zero_init") + + for name, get_components_func in non_distributed_component_funcs._registry.items(): + # because the ZeroInitContext automatically turns parameters to fp16 + # and the beit model use tensor.erfinv_() function to initialize weights + # tensor.erfinv_() doesn't support Half in CPU, we omit the beit model + if name == 'beit': + continue + model_builder, _, _, _, _ = get_components_func() + if init_device_type == 'cuda': + init_device = get_current_device() + elif init_device_type == 'cpu': + init_device = torch.device("cpu") + else: + continue + + model_numel_tensor = torch.zeros(1, dtype=torch.int) + with ZeroInitContext(target_device=init_device, + shard_strategy=shard_strategy_class(), + shard_param=True, + model_numel_tensor=model_numel_tensor): + model = model_builder(checkpoint=True) + + for param in model.parameters(): + assert hasattr(param, 'colo_attr') + assert param.colo_attr.sharded_data_tensor.dtype == torch.half + assert param.colo_attr.sharded_data_tensor.is_sharded + assert param.colo_attr.data_payload.device.type == init_device.type, \ + f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' + + cuda_mem_use, _ = colo_model_mem_usage(model) + model_data_cuda_mem_MB = cuda_mem_use / 1e6 + logger.info(f"Existing ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0]) + sys_cuda_mem_MB = colo_device_memory_used(get_current_device()) / 1e6 + logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0]) + logger.info(f"Model Number Parameter {model_numel_tensor.numpy()[0]/1e6} M", ranks=[0]) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_model_test() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 4]) +@rerun_if_address_is_in_use() +def test_zero_init_context(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_zero_init_context(1) diff --git a/tests/test_zero/test_legacy/test_param_op.py b/tests/test_zero/test_legacy/test_param_op.py new file mode 100644 index 0000000000000000000000000000000000000000..b91371b98922d25c69af15c6ba645966a1a2ac88 --- /dev/null +++ b/tests/test_zero/test_legacy/test_param_op.py @@ -0,0 +1,82 @@ +import copy + +import torch + +from colossalai.testing import clear_cache_before_run +from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr +from tests.components_to_test.registry import non_distributed_component_funcs + + +def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: + if loose: + return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3) + return torch.allclose(tensor_a, tensor_b) + + +def run_model(model, inputs, label, criterion, use_param_hook=False): + if use_param_hook: + + class HooKWrapper: + + def __init__(self) -> None: + self.hook_triggered_times = 0 + + def wrapper_func(self): + + def hook(param, grad) -> torch.Tensor or None: + self.hook_triggered_times += 1 + return grad + + return hook + + hookwrapper = HooKWrapper() + param_list = [p for p in model.parameters()] + hook_mgr = BaseParamHookMgr(param_list) + hook_mgr.register_backward_hooks(hookwrapper.wrapper_func()) + + model.zero_grad(set_to_none=True) + + with torch.cuda.amp.autocast(): + if criterion: + y = model(inputs) + loss = criterion(y, label) + else: + loss = model(inputs, label) + loss = loss.float() + loss.backward() + + if use_param_hook: + hook_mgr.remove_hooks() + return hookwrapper.hook_triggered_times + + +@clear_cache_before_run() +def test_base_param_hook(): + test_models = ['repeated_computed_layers', 'resnet18', 'hanging_param_model', 'inline_op_model'] + # test_models = ['bert'] + + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, _, criterion = get_components_func() + + torch.manual_seed(0) + model = model_builder(checkpoint=True).cuda() + model.train() + + for i, (inputs, label) in enumerate(train_dataloader): + if i > 0: + break + model_copy = copy.deepcopy(model) + + run_model(model, inputs.cuda(), label.cuda(), criterion, False) + ret2 = run_model(model_copy, inputs.cuda(), label.cuda(), criterion, True) + + # Make sure param hook has only be fired once in case of parameter sharing + assert ret2 == len(list(model.parameters())) + + for p, p_copy in zip(model.parameters(), model_copy.parameters()): + assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}" + + +if __name__ == '__main__': + test_base_param_hook() diff --git a/tests/test_zero/test_legacy/test_shard_model_v2.py b/tests/test_zero/test_legacy/test_shard_model_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..93d624aa2bbd66c17a2636370687358c3c4fe25e --- /dev/null +++ b/tests/test_zero/test_legacy/test_shard_model_v2.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pytest +import torch +from common import CONFIG, check_grads_padding, run_fwd_bwd +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy +from tests.components_to_test.registry import non_distributed_component_funcs + + +@parameterize("enable_autocast", [True]) +@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) +def run_model_test(enable_autocast, shard_strategy_class): + test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model'] + shard_strategy = shard_strategy_class() + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, _, criterion = get_components_func() + + with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = model_builder(checkpoint=True) + zero_model = ShardedModelV2(zero_model, shard_strategy) + + model = model_builder(checkpoint=True).half() + col_model_deepcopy(zero_model, model) + model = model.cuda() + + model = DDP(model, device_ids=[torch.cuda.current_device()]) + + for i, (data, label) in enumerate(train_dataloader): + if i > 5: + break + + data, label = cast_tensor_to_fp16(data).cuda(), label.cuda() + run_fwd_bwd(model, data, label, criterion, enable_autocast) + run_fwd_bwd(zero_model, data, label, criterion, enable_autocast) + + check_grads_padding(model, zero_model, loose=True) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_model_test() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_shard_model_v2(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_shard_model_v2(world_size=2) diff --git a/tests/test_zero/test_legacy/test_shard_param.py b/tests/test_zero/test_legacy/test_shard_param.py new file mode 100644 index 0000000000000000000000000000000000000000..4ba43edceb5d1cd9ad2413ac2c0d771084d6056d --- /dev/null +++ b/tests/test_zero/test_legacy/test_shard_param.py @@ -0,0 +1,91 @@ +from copy import deepcopy + +import pytest +import torch +from common import CONFIG, allclose + +import colossalai +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_param import ShardedTensor +from colossalai.zero.legacy.sharded_param.sharded_param import ShardedParamV2 + + +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_shard_tensor_with_strategy(shard_strategy_class, world_size): + t = ShardedTensor(tensor=torch.randn(world_size * 2, 3)) + assert list(t.origin_shape) == [world_size * 2, 3] + assert list(t.shape) == [world_size * 2, 3] + + shard_strategy = shard_strategy_class() + + # test shard strategy + shard_strategy.shard([t]) + assert list(t.shape) == [6], f"{list(t.shape)} vs 6" + shard_strategy.gather([t]) + assert list(t.shape) == [world_size * 2, 3], f"{list(t.shape)} vs {[world_size * 2, 3]}" + + +def _run_shard_tensor(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_shard_tensor_with_strategy(world_size=world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_shard_tensor(world_size): + spawn(_run_shard_tensor, world_size) + + +def _run_shard_param_v2(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + param = torch.nn.Parameter(torch.randn(2, 3)) + param_ref = deepcopy(param) + sparam = ShardedParamV2(param=param) + + allclose(sparam.data_payload, param_ref.data) + + # Test get memory usage + sparam.saved_grad = StatefulTensor(torch.randn(2, 3)) + cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() + assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}" + + sparam.set_data_none() + assert (param.data.numel() == 0) + cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() + # 4 is size of dummy tensor of param.data + assert cpu_mem_use == 2 * 3 * 4 * 2 + + sparam.saved_grad = StatefulTensor(torch.randn(2, 3)) + sparam.set_data_none() + cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() + assert cpu_mem_use == 2 * 3 * 4 * 2 + assert cuda_mem_use == 0 + + # append a grad to torch param + param.data = sparam.data_payload + param.grad = torch.randn(2, 3) + cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() + assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}" + assert cuda_mem_use == 0 + + # reuse torch grad for sparam + sparam.saved_grad = StatefulTensor(param.grad) + cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() + assert cpu_mem_use == 2 * 3 * 4 * 2 + assert cuda_mem_use == 0 + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_shard_param_v2(world_size): + spawn(_run_shard_param_v2, world_size) + + +if __name__ == '__main__': + # test_shard_tensor(2) + test_shard_param_v2(2) diff --git a/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py b/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..1ca144662722df86800431547d9855ef7f52e4b4 --- /dev/null +++ b/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py @@ -0,0 +1,89 @@ +import pytest +import torch + +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor import ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed + + +def init_zero(model_builder, placement_policy): + device = get_current_device() if placement_policy == 'cuda' else torch.device('cpu') + shard_strategy = TensorShardStrategy() + with ZeroInitContext(target_device=device, shard_strategy=shard_strategy, shard_param=True): + model = model_builder() + model = ShardedModelV2( + model, + shard_strategy, + tensor_placement_policy=placement_policy, + reuse_fp16_shard=True, + ) + optim = HybridAdam(model.parameters(), lr=1e-3) + optim = ShardedOptimizerV2(model, optim, initial_scale=32) + return model, optim + + +def run_step(model, optim, criterion, data, label): + optim.zero_grad() + logits = model(data) + loss = criterion(logits, label) + optim.backward(loss) + optim.step() + + +def check_state_dict_eq(state_dict, other): + for p, state in state_dict['state'].items(): + other_state = other['state'][p] + for k, v in state.items(): + if isinstance(v, torch.Tensor): + assert torch.allclose(v, other_state[k], atol=1e-3), f'{v} vs {other_state[k]}' + else: + assert v == other_state[k] + + +@parameterize('placement_policy', ['cuda', 'cpu']) +def run_nested_model(placement_policy): + get_components_func = non_distributed_component_funcs.get_callable('simple_net') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + set_seed(42) + model, optim = init_zero(model_builder, placement_policy) + set_seed(42) + model_copy, optim_copy = init_zero(model_builder, placement_policy) + + model.train() + model_copy.train() + pg = ProcessGroup() + set_seed(pg.dp_local_rank()) + data_iter = iter(train_dataloader) + + data, label = map(lambda x: x.cuda(), next(data_iter)) + run_step(model, optim, criterion, data, label) + optim_copy.load_state_dict(optim.state_dict()) + check_state_dict_eq(optim.state_dict(), optim_copy.state_dict()) + + data, label = map(lambda x: x.cuda(), next(data_iter)) + run_step(model_copy, optim_copy, criterion, data, label) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_nested_model() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_sharded_optim_state_dist(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_sharded_optim_state_dist(2) diff --git a/tests/test_zero/test_legacy/test_sharded_optim_v2.py b/tests/test_zero/test_legacy/test_sharded_optim_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f77995ebcd7820795d6a93e49195105a79bffd --- /dev/null +++ b/tests/test_zero/test_legacy/test_sharded_optim_v2.py @@ -0,0 +1,110 @@ +import pytest +import torch +import torch.distributed as dist +from common import CONFIG, check_sharded_model_params +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.nn.optimizer import CPUAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy +from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 +from colossalai.zero.low_level._utils import has_inf_or_nan +from tests.components_to_test.registry import non_distributed_component_funcs + + +def _run_step(model, optimizer, data, label, criterion, enable_autocast=False): + model.train() + optimizer.zero_grad() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + + loss = loss.float() + if isinstance(model, ShardedModelV2): + optimizer.backward(loss) + else: + loss.backward() + optimizer.step() + + +@parameterize("cpu_offload", [True, False]) +@parameterize("use_cpuadam", [True, False]) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) +def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio): + test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model'] + shard_strategy = shard_strategy_class() + + if use_cpuadam and cpu_offload is False: + return + if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam): + return + + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() + + with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = model_builder(checkpoint=True) + zero_model = ShardedModelV2( + zero_model, + shard_strategy, + tensor_placement_policy='cpu' if cpu_offload else 'auto', + reuse_fp16_shard=use_cpuadam, + ) + + model = model_builder(checkpoint=True).half() + col_model_deepcopy(zero_model, model) + model = model.cuda().float() + + if use_cpuadam: + optimizer_class = CPUAdam + optim = optimizer_class(model.parameters(), lr=1e-3) + sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) + sharded_optim = ShardedOptimizerV2(zero_model, + sharded_optim, + initial_scale=2**5, + gpu_margin_mem_ratio=gpu_margin_mem_ratio) + + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False) + apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) + if dist.get_world_size() > 1: + apex_model = DDP(apex_model, device_ids=[torch.cuda.current_device()]) + + for i, (data, label) in enumerate(train_dataloader): + if i > 5: + break + data, label = data.cuda(), label.cuda() + _run_step(apex_model, apex_optimizer, data, label, criterion, False) + _run_step(zero_model, sharded_optim, data, label, criterion, False) + check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam) + for param in model.parameters(): + assert not has_inf_or_nan(param) + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_test_sharded_optim_v2() + + +# use_cpuadam = True can be used with cpu_offload = False +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_sharded_optim_v2(world_size): + spawn(_run_dist, world_size) + + +if __name__ == '__main__': + test_sharded_optim_v2(world_size=2) diff --git a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py b/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..61d850d06080fd444e365fe5d4e5b91d284f42d8 --- /dev/null +++ b/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pytest +import torch +import torch.distributed as dist +from torchvision.models import resnet50 + +import colossalai +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import TensorShardStrategy + + +def run_dist(rank, world_size, port): + # this test only runs on resnet18 + # as this model has sync batch normalization + # need to configure cudnn deterministic so that + # randomness of convolution layers will be disabled + zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy())) + colossalai.launch(config=dict(zero=zero_config, cudnn_determinstic=True, cudnn_benchmark=False), + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') + + with ZeroInitContext(target_device=torch.cuda.current_device(), + shard_strategy=gpc.config.zero.model_config.shard_strategy, + shard_param=True): + model = resnet50() + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + criterion = torch.nn.CrossEntropyLoss() + + engine, *args = colossalai.initialize(model, optimizer, criterion) + + # train for dummy iterations + engine.train() + for _ in range(2): + data = torch.rand(4, 3, 128, 128).cuda().half() + label = torch.randint(0, 10, size=(4,)).cuda() + engine.zero_grad() + out = engine(data) + loss = engine.criterion(out, label) + engine.backward(loss) + engine.step() + + # test + # need to make sure the batch norm stats are synchronized + # so that given the same input, the model will produce the same + # output on different ranks + engine.eval() + data = torch.rand(4, 3, 128, 128).cuda().half() + dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA)) + + # predict + out = engine(data) + + # test if results are equal + tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)] + tensor_list.insert(rank, out) + dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA)) + + assert torch.all(tensor_list[0] == tensor_list[1]), \ + 'expected the output from different ranks to be the same, but got different values' + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_sharded_optim_with_sync_bn(): + """ + This test is to make sure that buffers are synchronized between ranks + when using ZeRO. An example of module buffer is the running stats of + BatchNormalization layer, i.e. mean and var. + + If the buffers are not synchronized, the model will produce different + output even though the input and parameters are the same. This is not + wanted if we are doing predictions. + + """ + spawn(run_dist, 2) + + +if __name__ == '__main__': + test_sharded_optim_with_sync_bn() diff --git a/tests/test_zero/test_legacy/test_state_dict.py b/tests/test_zero/test_legacy/test_state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..5f76fff3e5c372eeb7b62280c73b04082d289901 --- /dev/null +++ b/tests/test_zero/test_legacy/test_state_dict.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from functools import partial + +import pytest +import torch +from common import CONFIG + +import colossalai +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy +from tests.components_to_test.registry import non_distributed_component_funcs + + +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_zero_state_dict(shard_strategy_class): + test_models = ['repeated_computed_layers', 'resnet18'] + shard_strategy = shard_strategy_class() + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func() + + with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = model_builder(checkpoint=True) + zero_model = ShardedModelV2(zero_model, shard_strategy) + + model = model_builder(checkpoint=True).half() + col_model_deepcopy(zero_model, model) + model = model.cuda() + + zero_state_dict = zero_model.state_dict() + for key, val in model.state_dict().items(): + assert torch.equal(val, zero_state_dict[key].to(val.device)) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_zero_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_zero_state_dict(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_zero_state_dict(2) diff --git a/tests/test_zero/test_legacy/test_tensor_utils.py b/tests/test_zero/test_legacy/test_tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..238bc3fe1a98084df12381ae48a4296c9618c08b --- /dev/null +++ b/tests/test_zero/test_legacy/test_tensor_utils.py @@ -0,0 +1,94 @@ +import pytest +import torch + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor +from colossalai.zero.legacy.gemini.tensor_utils import ( + colo_model_data_move_to_cpu, + colo_model_data_tensor_move, + colo_model_data_tensor_move_inline, + colo_model_tensor_clone, + colo_tensor_mem_usage, +) + + +def _run_colo_tensor_mem_usage(): + for i in range(1): + if i == 1: + t1 = StatefulTensor(torch.randn(2, 2)) + t2 = StatefulTensor(torch.randn(4, 4)) + c1, g1 = colo_tensor_mem_usage(t1) + c2, g2 = colo_tensor_mem_usage(t2) + assert c1 * 4 == c2 + assert g1 * 4 == g2 + else: + t1 = torch.randn(2, 2) + t2 = torch.randn(4, 4) + c1, g1 = colo_tensor_mem_usage(t1) + c2, g2 = colo_tensor_mem_usage(t2) + assert c1 * 4 == c2 + assert g1 * 4 == g2 + + +def _run_colo_model_data_tensor_move_inline(): + for t in [StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3)]: + colo_model_data_tensor_move_inline(t, get_current_device()) + assert t.device == get_current_device() + + +def _run_colo_model_data_tensor_move(): + for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).to(get_current_device()))), + (torch.ones(2, 3), torch.zeros(2, 3).to(get_current_device()))]: + cpu_t, cuda_t = t + colo_model_data_tensor_move(cpu_t, cuda_t) + assert cuda_t.device == get_current_device() + + +def _run_colo_model_data_move_to_cpu(): + for t in [StatefulTensor(torch.randn(2, 2)), torch.randn(4, 4)]: + colo_model_data_move_to_cpu(t) + assert t.device == torch.device("cpu") + + +def _run_colo_model_tensor_clone(): + for t in [ + StatefulTensor(torch.randn(2, 2).cuda(torch.cuda.current_device())), + torch.randn(4, 4).cuda(torch.cuda.current_device()) + ]: + if issubclass(type(t), StatefulTensor): + assert t.payload.device == get_current_device() + else: + assert t.device == get_current_device() + p = colo_model_tensor_clone(t, get_current_device()) + assert p.device == get_current_device() + for i in range(2): + for j in range(2): + if issubclass(type(t), StatefulTensor): + assert t.payload.device == p.device + assert t.payload[i][j] == p[i][j] + else: + assert t.device == p.device + assert t[i][j] == p[i][j] + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + _run_colo_tensor_mem_usage() + _run_colo_model_data_tensor_move_inline() + _run_colo_model_data_tensor_move() + _run_colo_model_data_move_to_cpu() + _run_colo_model_tensor_clone() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2, 4]) +@rerun_if_address_is_in_use() +def test_zero_tensor_utils(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_zero_tensor_utils(world_size=2) diff --git a/tests/test_zero/test_legacy/test_zero_engine.py b/tests/test_zero/test_legacy/test_zero_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8847ce56ab97c188684c1cf154f0d37628f5d5 --- /dev/null +++ b/tests/test_zero/test_legacy/test_zero_engine.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pytest +import torch +import torch.distributed as dist +from common import MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy +from colossalai.zero.low_level._utils import has_inf_or_nan +from tests.components_to_test.registry import non_distributed_component_funcs + + +def run_dist(rank, world_size, port, parallel_config): + colossalai.launch(config=parallel_config, + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') + + test_models = ['repeated_computed_layers', 'resnet18', 'bert'] + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() + with ZeroInitContext(target_device=torch.cuda.current_device(), + shard_strategy=gpc.config.zero.model_config.shard_strategy, + shard_param=True): + colo_model = model_builder(checkpoint=True) + + colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3) + engine, train_dataloader, _, _ = colossalai.initialize(colo_model, + optimizer=colo_optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + torch_model = model_builder(checkpoint=True).half() + col_model_deepcopy(engine.model, torch_model) + torch_model = torch_model.cuda().float() + + engine.train() + torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3) + + if dist.get_world_size() > 1: + torch_model = DDP(torch_model, device_ids=[torch.cuda.current_device()]) + + i = 0 + for data, label in train_dataloader: + if i > 4: + break + + data, label = data.cuda(), label.cuda() + + engine.zero_grad() + torch_optimizer.zero_grad() + + if criterion: + output = engine(data) + loss = engine.criterion(output, label) + + torch_output = torch_model(data) + torch_loss = engine.criterion(torch_output, label) + else: + loss = engine(data, label) + torch_loss = torch_model(data, label) + + engine.backward(loss) + engine.step() + + torch_loss.backward() + + for param in torch_model.parameters(): + if param.grad is not None: + assert not has_inf_or_nan(param.grad) + + torch_optimizer.step() + i += 1 + + if parallel_config == MP_PARALLEL_CONFIG: + check_params(torch_model, colo_model, loose=True) + elif parallel_config == ZERO_PARALLEL_CONFIG: + check_sharded_model_params(torch_model, colo_model, loose=True) + + +# FIXME: enable this test in next PR +@pytest.mark.skip +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2, 4]) +@rerun_if_address_is_in_use() +def test_mp_engine(world_size): + spawn(run_dist, world_size, parallel_config=MP_PARALLEL_CONFIG) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_zero_engine(world_size): + spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG) + + +if __name__ == '__main__': + test_zero_engine(world_size=4) diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py new file mode 100644 index 0000000000000000000000000000000000000000..2ae1f3a99d79b3d3657617a98cdafad7a10cb683 --- /dev/null +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -0,0 +1,163 @@ +import copy + +import pytest +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.testing import spawn +from colossalai.testing.random import seed_all +from colossalai.zero import LowLevelZeroOptimizer + + +class MlpModel(nn.Module): + + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(128, 256) + self.linear2 = nn.Linear(256, 512) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +def exam_zero_1_2_grad_acc(): + local_rank = torch.distributed.get_rank() + seed_all(2009) + + # create model + zero1_model = MlpModel().cuda() + zero2_model = copy.deepcopy(zero1_model) + # create optimizer + zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) + zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) + zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer, + overlap_communication=True, + initial_scale=32, + clip_grad_norm=1.0, + verbose=True) + zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, + overlap_communication=True, + partition_grad=True, + initial_scale=32, + clip_grad_norm=1.0) + # create data + seed_all(2021 + local_rank) + input_data1 = torch.randn(32, 128).cuda() + input_data2 = torch.randn(32, 128).cuda() + + def fwd_bwd_func(number, cur_data): + # zero-dp forward + zero1_output = zero1_model(cur_data) + zero2_output = zero2_model(cur_data) + assert torch.equal(zero1_output, zero2_output) + + # zero-dp backward + zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False) + zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False) + + for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): + if z2p.grad is not None: + # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad))) + assert torch.equal(z1p.grad, z2p.grad) + + zero1_optimizer._sync_grad() + zero2_optimizer._sync_grad() + + fwd_bwd_func(0, input_data1) + fwd_bwd_func(1, input_data2) + + # step + zero1_optimizer.step() + zero2_optimizer.step() + + # check updated param + for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): + assert torch.equal(z1p.data, z2p.data) + + +def exam_zero_1_grad_acc(): + local_rank = torch.distributed.get_rank() + grad_scale = 32 + seed_all(2008) + + # create models + zero_model = MlpModel() + torch_model = copy.deepcopy(zero_model) + + seed_all(2008) + zero_model = zero_model.cuda() + torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) + + # create optimizer + zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1) + + # we only test stage 1 here + # in `check_sharded_param_consistency.py`, we will test whether + # level 1 and 2 will produce exactly the same results + zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, + overlap_communication=False, + initial_scale=grad_scale, + reduce_bucket_size=262144, + clip_grad_norm=1.0) + + torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) + + # create data + seed_all(2022 + local_rank) + input_data1 = torch.randn(32, 128).cuda() + input_data2 = torch.randn(32, 128).cuda() + + def fwd_bwd_func(number, cur_data, check_flag): + # zero-dp forward + zero_output = zero_model(cur_data) + + # torch-ddp forward + torch_output = torch_model(cur_data) + assert torch.equal(zero_output, torch_output) + + # zero-dp backward + zero_optimizer.backward(zero_output.sum().float(), sync_grad=False) + # torch-ddp backward + torch_output.sum().backward() + + if check_flag: + # check grad + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + unscale_grad = z1p.grad / grad_scale + # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad))) + assert torch.equal(p.grad, unscale_grad) + + zero_optimizer._sync_grad() + + fwd_bwd_func(0, input_data1, True) + fwd_bwd_func(1, input_data2, False) + + zero_optimizer.step() + torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0) + torch_optimizer.step() + + # check updated param + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + # print(n, p.shape, torch.max(p.data), torch.max(z1p.data), torch.max(torch.abs(p.data - z1p.data))) + assert_close(p.data, z1p.data) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + + exam_zero_1_grad_acc() + exam_zero_1_2_grad_acc() + + +@pytest.mark.dist +def test_grad_accumulation(): + spawn(run_dist, 2) + + +if __name__ == '__main__': + test_grad_accumulation() diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py new file mode 100644 index 0000000000000000000000000000000000000000..4086af9d896e83ba51953b195fd23f15e301628a --- /dev/null +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -0,0 +1,183 @@ +import copy + +import pytest +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all +from colossalai.zero import LowLevelZeroOptimizer + + +class MlpModel(nn.Module): + + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(128, 256) + self.linear2 = nn.Linear(256, 512) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +def half_close(a, b, loose=False): + rtol = None + atol = None + if loose: + rtol = 5e-2 + atol = 5e-4 + + a = a.detach().half() + b = b.detach().half() + + assert_close(a, b, rtol=rtol, atol=atol) + + +def exam_zero_1_2(): + """ + In this test, we want to test whether zero stage 1 and 2 + deliver the same numerical results despite different communication + pattern + + we use these prefixes to differentiate the zero stage + oss: partition optimizer states + pg: partition gradients and optimizer states + + """ + local_rank = torch.distributed.get_rank() + seed_all(2001) + + # create model + zero1_model = MlpModel().cuda() + zero2_model = copy.deepcopy(zero1_model) + + # create optimizer + zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) + zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) + zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer, + overlap_communication=True, + initial_scale=128, + verbose=True) + zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, + overlap_communication=True, + partition_grad=True, + initial_scale=128) + # create data + seed_all(2001 + local_rank) + input_data = torch.randn(32, 128).cuda() + + zero1_output = zero1_model(input_data) + zero2_output = zero2_model(input_data) + assert torch.equal(zero1_output, zero2_output) + + # zero-dp backward + zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False) + zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False) + + for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): + if z2p.grad is not None: + # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad))) + assert torch.equal(z1p.grad, z2p.grad) + + zero1_optimizer._sync_grad() + zero2_optimizer._sync_grad() + + # step + zero1_optimizer.step() + zero2_optimizer.step() + + # check updated param + for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): + assert torch.equal(z1p.data, z2p.data) + + +def exam_zero_1_torch_ddp(): + """ + In this test, two pairs of model and optimizers are created. + 1. zero: use sharded optimizer and fp16 parameters + 2. torch: use torch DDP and fp32 parameters + + We feed these two sets of models with the same input and check if the + differences in model output and updated parameters are within tolerance. + """ + local_rank = torch.distributed.get_rank() + seed_all(1453) + + # create models + zero_model = MlpModel() + torch_model = copy.deepcopy(zero_model) + + zero_model = zero_model.cuda().half() + torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) + torch_model = torch_model.cuda() + + # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + # half_close(p.data, z1p.data) + + # create optimizer + zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) + + # we only test stage 1 here + # in `check_sharded_param_consistency.py`, we will test whether + # level 1 and 2 will produce exactly the same results + zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, + overlap_communication=True, + initial_scale=1, + reduce_bucket_size=262144) + + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + + seed_all(1453 + local_rank) + # create + input_data = torch.rand(32, 128).cuda() + + # zero-dp forward + zero_output = zero_model(input_data.half()) + + # torch-ddp forward + torch_output = torch_model(input_data) + half_close(zero_output, torch_output, loose=True) + + # zero-dp backward + zero_optimizer.backward(zero_output.mean().float(), sync_grad=False) + + # torch-ddp backward + torch_output.mean().backward() + + # check grad + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + half_close(p.grad, z1p.grad, loose=True) + + # zero-dp step + zero_optimizer._sync_grad() + zero_optimizer.step() + + # torch ddp step + torch_optimizer.step() + + # check updated param + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + # print(n, torch.max(torch.abs(p.data - z1p.data))) + half_close(p.data, z1p.data, loose=True) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + + exam_zero_1_torch_ddp() + exam_zero_1_2() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_zero_1_2(): + spawn(run_dist, 2) + + +if __name__ == '__main__': + test_zero_1_2() diff --git a/tests/test_zero/test_low_level/test_zero_init.py b/tests/test_zero/test_low_level/test_zero_init.py new file mode 100644 index 0000000000000000000000000000000000000000..aeeaff5b5cb92a08552fc2bd9b65cd046c25a9e5 --- /dev/null +++ b/tests/test_zero/test_low_level/test_zero_init.py @@ -0,0 +1,56 @@ +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn + +import colossalai +from colossalai.tensor import ProcessGroup +from colossalai.testing import spawn +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer + + +class MlpModel(nn.Module): + + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(128, 256) + self.linear2 = nn.Linear(256, 512) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +def exam_zero_init(): + dp_2_tp_2_pg = ProcessGroup(dp_degree=2, tp_degree=2) + model1 = MlpModel().cuda() + with ColoInitContext(device=get_current_device(), default_pg=dp_2_tp_2_pg): + model2 = MlpModel() + optimizer1 = LowLevelZeroOptimizer(torch.optim.Adam(model1.parameters(), lr=1)) + optimizer2 = LowLevelZeroOptimizer(torch.optim.Adam(model2.parameters(), lr=1)) + + assert optimizer1._local_rank == optimizer2._local_rank + assert optimizer1._world_size == optimizer2._world_size + assert optimizer1._dp_global_ranks == optimizer2._dp_global_ranks + + mp_group1 = optimizer1._mp_torch_group + mp_group2 = optimizer2._mp_torch_group + assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2) + assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2) + + +def run_dist(rank, world_size, port): + config_dict = dict(parallel=dict(data=2, tensor=dict(size=2, mode='1d'))) + colossalai.launch(config=config_dict, rank=rank, world_size=world_size, port=port, host='localhost') + exam_zero_init() + + +@pytest.mark.dist +def test_zero_init(): + spawn(run_dist, 4) + + +if __name__ == '__main__': + test_zero_init() diff --git a/tests/test_zero/test_low_level/test_zero_tp.py b/tests/test_zero/test_low_level/test_zero_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..f0804f4bb5ba51b5990b91d1b43f3dd365c93b7a --- /dev/null +++ b/tests/test_zero/test_low_level/test_zero_tp.py @@ -0,0 +1,93 @@ +import pytest +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.tensor import ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer +from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal + + +def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4): + return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol) + + +class MlpModel(nn.Module): + + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(32, 128) + self.act = nn.GELU() + self.linear2 = nn.Linear(128, 32) + + def forward(self, x): + y = self.linear1(x) + y = self.act(y) + y = self.linear2(y) + return x + y + + +@parameterize("overlap_flag", [False, True]) +@parameterize("partition_flag", [False, True]) +def exam_zero_with_tp(overlap_flag, partition_flag): + set_seed(233010) + tp_pg = ProcessGroup(tp_degree=2) + + with ColoInitContext(device=get_current_device(), default_pg=tp_pg): + hybrid_model = MlpModel() + torch_model = MlpModel().cuda() + for pt, ph in zip(torch_model.parameters(), hybrid_model.parameters()): + pt.data.copy_(ph.data) + + for name, param in hybrid_model.named_parameters(): + if 'linear1' in name: + split_param_row_tp1d(param, tp_pg) + param.compute_spec.set_output_replicate(False) + if 'linear2.weight' in name: + split_param_col_tp1d(param, tp_pg) + + torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group()) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-2) # set to 1e-2 for torch-1.11 + hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1e-2) + hybrid_optim = LowLevelZeroOptimizer(hybrid_optim, + initial_scale=2, + clip_grad_norm=1.0, + overlap_communication=overlap_flag, + partition_grad=partition_flag) + + dp_local_rank = tp_pg.dp_local_rank() + set_seed(255 + dp_local_rank) + + data = torch.randn(8, 32, device=get_current_device()) + torch_loss = torch_model(data).sum() + hybrid_loss = hybrid_model(data).sum() + assert_close(torch_loss, hybrid_loss) + + torch_loss.backward() + torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0) + hybrid_optim.backward(hybrid_loss) + + torch_optim.step() + hybrid_optim.step() + + for (name, pt), ph in zip(torch_model.named_parameters(), hybrid_model.parameters()): + assert strict_shard_equal(pt.data, ph.data, tp_pg) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + exam_zero_with_tp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_zero_with_tp(): + spawn(run_dist, 4) + + +if __name__ == '__main__': + test_zero_with_tp() diff --git a/version.txt b/version.txt new file mode 100644 index 0000000000000000000000000000000000000000..a45be4627678330112f94a9b48a5e821ed846104 --- /dev/null +++ b/version.txt @@ -0,0 +1 @@ +0.2.8